1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::SubmitEditPredictionFeedbackBody;
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::Edit;
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::onboarding_modal::ZedPredictModal;
81pub use crate::prediction::EditPrediction;
82pub use crate::prediction::EditPredictionId;
83use crate::prediction::EditPredictionResult;
84pub use crate::sweep_ai::SweepAi;
85pub use capture_example::capture_example;
86pub use language_model::ApiKeyState;
87pub use telemetry_events::EditPredictionRating;
88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
89
// Declares the `ResetOnboarding` and `ClearHistory` actions under the
// `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line-distance threshold used by `StoredEvent::can_merge` when deciding
/// whether two edits are close enough to each other (or to the latest edit).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the pause length that separates groups of edits;
// the code that consumes this is outside the reviewed section — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): looks like a key-value store key for the persisted data
// collection choice — confirm against `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce applied when batching rejection
// reports; the consumer is outside the reviewed section — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
106
/// Marker type for the feature flag named `"zeta2"`.
pub struct Zeta2FeatureFlag;
/// Marker type for the feature flag named `"edit_prediction_jumps"`.
/// Checked in `handle_project_event` before refreshing predictions from diagnostics.
pub struct EditPredictionJumpsFeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";
}

impl FeatureFlag for EditPredictionJumpsFeatureFlag {
    const NAME: &'static str = "edit_prediction_jumps";
}
117
/// App-global wrapper around the shared [`EditPredictionStore`] entity.
/// Installed by `EditPredictionStore::global` and read by `try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
122
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
// NOTE(review): there is no `version` field here; "the version" presumably
// refers to `format` (which is parsed from `ZED_ZETA_FORMAT`) — confirm.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier; read from `ZED_ZETA_MODEL` when built from the environment.
    pub model_id: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when built from the environment.
    pub format: ZetaFormat,
}
131
/// Store for edit predictions: holds the selected prediction model, provider
/// handles, and per-project state keyed by project entity id.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Subscription to the global `RefreshLlmTokenListener`; its callback refreshes `llm_token`.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's `EntityId`.
    projects: HashMap<EntityId, ProjectState>,
    // Initialised to `false` in `new`.
    update_required: bool,
    // Selected model; `new` defaults it to `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    // Optional raw-endpoint configuration; seeded from the environment in `new`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sending half of the channel whose receiver is handed to
    // `handle_rejected_predictions` on a background task in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
148
/// The prediction backend selected on the [`EditPredictionStore`].
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum EditPredictionModel {
    Zeta1,
    Zeta2,
    /// Fill-in-the-middle completion using the given prompt format.
    Fim { format: EditPredictionPromptFormat },
    Sweep,
    Mercury,
}
157
/// Bundle of inputs handed to a prediction model for a single request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Position in `buffer` the prediction is requested for.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Optional channel for emitting `DebugEvent`s about this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
172
/// Events delivered over the channel returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
180
/// Payload of [`DebugEvent::ContextRetrievalStarted`], sent when the related
/// excerpt store reports `StartedRefresh`.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Sent as an empty string by `handle_excerpt_store_event`.
    pub search_prompt: String,
}
187
/// Payload of [`DebugEvent::ContextRetrievalFinished`], sent when the related
/// excerpt store reports `FinishedRefresh`.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs (cache hits, max and mean LSP time).
    pub metadata: Vec<(&'static str, SharedString)>,
}
194
/// Payload of [`DebugEvent::EditPredictionStarted`].
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The prompt, when one is available.
    pub prompt: Option<String>,
}
201
/// Payload of [`DebugEvent::EditPredictionFinished`].
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// The model output, when one is available.
    pub model_output: Option<String>,
}
208
/// Maximum number of user actions kept per project; see
/// `ProjectState::record_user_action`, which evicts the oldest entry at this size.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // NOTE(review): presumably milliseconds since the Unix epoch — confirm at the producer.
    pub timestamp_epoch_ms: u64,
}
219
/// Kind of user action stored in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
228
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-level event.
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change (see `LastEvent::finalize`).
    pub old_snapshot: TextBufferSnapshot,
    /// Range of the change, anchored in the snapshot taken after the change.
    pub edit_range: Range<Anchor>,
}
236
237impl StoredEvent {
238 fn can_merge(
239 &self,
240 next_old_event: &&&StoredEvent,
241 new_snapshot: &TextBufferSnapshot,
242 last_edit_range: &Range<Anchor>,
243 ) -> bool {
244 // Events must be for the same buffer
245 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
246 return false;
247 }
248 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
249 return false;
250 }
251
252 let a_is_predicted = matches!(
253 self.event.as_ref(),
254 zeta_prompt::Event::BufferChange {
255 predicted: true,
256 ..
257 }
258 );
259 let b_is_predicted = matches!(
260 next_old_event.event.as_ref(),
261 zeta_prompt::Event::BufferChange {
262 predicted: true,
263 ..
264 }
265 );
266
267 // If events come from the same source (both predicted or both manual) then
268 // we would have coalesced them already.
269 if a_is_predicted == b_is_predicted {
270 return false;
271 }
272
273 let left_range = self.edit_range.to_point(new_snapshot);
274 let right_range = next_old_event.edit_range.to_point(new_snapshot);
275 let latest_range = last_edit_range.to_point(&new_snapshot);
276
277 // Events near to the latest edit are not merged if their sources differ.
278 if lines_between_ranges(&left_range, &latest_range)
279 .min(lines_between_ranges(&right_range, &latest_range))
280 <= CHANGE_GROUPING_LINE_SPAN
281 {
282 return false;
283 }
284
285 // Events that are distant from each other are not merged.
286 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
287 return false;
288 }
289
290 true
291 }
292}
293
294fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
295 if left.start > right.end {
296 return left.start.row - right.end.row;
297 }
298 if right.start > left.end {
299 return right.start.row - left.end.row;
300 }
301 0
302}
303
/// Per-project state tracked by the [`EditPredictionStore`].
struct ProjectState {
    // Finalized edit history.
    events: VecDeque<StoredEvent>,
    // Most recent, not-yet-finalized event; `events()` finalizes it on the fly.
    last_event: Option<LastEvent>,
    // Paths of recently active entries, most recent first (see `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    // Buffers registered with this project, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Set by `debug_info`; used to emit `DebugEvent`s for this project.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_edit_prediction_refresh: Option<(EntityId, Instant)>,
    last_jump_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One watcher per worktree, created lazily in `register_buffer_impl`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded history of user actions (at most `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    // Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
322
323impl ProjectState {
324 fn record_user_action(&mut self, action: UserActionRecord) {
325 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
326 self.user_actions.pop_front();
327 }
328 self.user_actions.push_back(action);
329 }
330
331 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
332 self.events
333 .iter()
334 .cloned()
335 .chain(self.last_event.as_ref().iter().flat_map(|event| {
336 let (one, two) = event.split_by_pause();
337 let one = one.finalize(&self.license_detection_watchers, cx);
338 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
339 one.into_iter().chain(two)
340 }))
341 .collect()
342 }
343
344 fn cancel_pending_prediction(
345 &mut self,
346 pending_prediction: PendingPrediction,
347 cx: &mut Context<EditPredictionStore>,
348 ) {
349 self.cancelled_predictions.insert(pending_prediction.id);
350
351 if pending_prediction.drop_on_cancel {
352 drop(pending_prediction.task);
353 } else {
354 cx.spawn(async move |this, cx| {
355 let Some(prediction_id) = pending_prediction.task.await else {
356 return;
357 };
358
359 this.update(cx, |this, cx| {
360 this.reject_prediction(
361 prediction_id,
362 EditPredictionRejectReason::Canceled,
363 false,
364 cx,
365 );
366 })
367 .ok();
368 })
369 .detach()
370 }
371 }
372
373 fn active_buffer(
374 &self,
375 project: &Entity<Project>,
376 cx: &App,
377 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
378 let project = project.read(cx);
379 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
380 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
381 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
382 Some((active_buffer, registered_buffer.last_position))
383 }
384}
385
/// The prediction currently held for a project, plus how it was requested
/// and whether/how it has been shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
393
394impl CurrentEditPrediction {
395 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
396 let Some(new_edits) = self
397 .prediction
398 .interpolate(&self.prediction.buffer.read(cx))
399 else {
400 return false;
401 };
402
403 if self.prediction.buffer != old_prediction.prediction.buffer {
404 return true;
405 }
406
407 let Some(old_edits) = old_prediction
408 .prediction
409 .interpolate(&old_prediction.prediction.buffer.read(cx))
410 else {
411 return true;
412 };
413
414 let requested_by_buffer_id = self.requested_by.buffer_id();
415
416 // This reduces the occurrence of UI thrash from replacing edits
417 //
418 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
419 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
420 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
421 && old_edits.len() == 1
422 && new_edits.len() == 1
423 {
424 let (old_range, old_text) = &old_edits[0];
425 let (new_range, new_text) = &new_edits[0];
426 new_range == old_range && new_text.starts_with(old_text.as_ref())
427 } else {
428 true
429 }
430 }
431}
432
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update; no originating buffer.
    DiagnosticsUpdate,
    /// Triggered from the buffer with the given entity id.
    Buffer(EntityId),
}
438
439impl PredictionRequestedBy {
440 pub fn buffer_id(&self) -> Option<EntityId> {
441 match self {
442 PredictionRequestedBy::DiagnosticsUpdate => None,
443 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
444 }
445 }
446}
447
// NOTE(review): presumably the number of lines around the cursor searched for
// diagnostics; the consumer is outside the reviewed section — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope used when refreshing predictions from diagnostics.
/// `handle_project_event` uses `Global` on `DiagnosticsUpdated`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
455
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    // Resolves to the id of the resulting prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
464
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // NOTE(review): presumably a prediction whose edits are in this buffer — confirm.
    Local { prediction: &'a EditPrediction },
    // NOTE(review): presumably a prediction that targets another location/buffer — confirm.
    Jump { prediction: &'a EditPrediction },
}
471
/// Test-only convenience: dereferences either variant to the wrapped prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction }
            | BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
483
/// State kept for a buffer registered with a project (see `register_buffer_impl`).
struct RegisteredBuffer {
    // The buffer's file at registration time, if any.
    file: Option<Arc<dyn File>>,
    // Text snapshot captured at registration time.
    snapshot: TextBufferSnapshot,
    // Initialised to `None` on registration; read by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    // Buffer event subscription and buffer release observer.
    _subscriptions: [gpui::Subscription; 2],
}
490
/// The most recent, still-accumulating buffer change for a project.
/// Converted into a [`StoredEvent`] by `finalize`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    // Propagated to `zeta_prompt::Event::BufferChange::predicted` by `finalize`.
    predicted: bool,
    // Boundary snapshot at which `split_by_pause` divides this event in two.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
502
503impl LastEvent {
504 pub fn finalize(
505 &self,
506 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
507 cx: &App,
508 ) -> Option<StoredEvent> {
509 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
510 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
511
512 let in_open_source_repo =
513 [self.new_file.as_ref(), self.old_file.as_ref()]
514 .iter()
515 .all(|file| {
516 file.is_some_and(|file| {
517 license_detection_watchers
518 .get(&file.worktree_id(cx))
519 .is_some_and(|watcher| watcher.is_project_open_source())
520 })
521 });
522
523 let (diff, edit_range) =
524 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
525
526 if path == old_path && diff.is_empty() {
527 None
528 } else {
529 Some(StoredEvent {
530 event: Arc::new(zeta_prompt::Event::BufferChange {
531 old_path,
532 path,
533 diff,
534 in_open_source_repo,
535 predicted: self.predicted,
536 }),
537 edit_range: self.new_snapshot.anchor_before(edit_range.start)
538 ..self.new_snapshot.anchor_before(edit_range.end),
539 old_snapshot: self.old_snapshot.clone(),
540 })
541 }
542 }
543
544 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
545 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
546 return (self.clone(), None);
547 };
548
549 let before = LastEvent {
550 old_snapshot: self.old_snapshot.clone(),
551 new_snapshot: boundary_snapshot.clone(),
552 old_file: self.old_file.clone(),
553 new_file: self.new_file.clone(),
554 edit_range: None,
555 predicted: self.predicted,
556 snapshot_after_last_editing_pause: None,
557 last_edit_time: self.last_edit_time,
558 };
559
560 let after = LastEvent {
561 old_snapshot: boundary_snapshot.clone(),
562 new_snapshot: self.new_snapshot.clone(),
563 old_file: self.old_file.clone(),
564 new_file: self.new_file.clone(),
565 edit_range: None,
566 predicted: self.predicted,
567 snapshot_after_last_editing_pause: None,
568 last_edit_time: self.last_edit_time,
569 };
570
571 (before, Some(after))
572 }
573}
574
575pub(crate) fn compute_diff_between_snapshots(
576 old_snapshot: &TextBufferSnapshot,
577 new_snapshot: &TextBufferSnapshot,
578) -> Option<(String, Range<Point>)> {
579 let edits: Vec<Edit<usize>> = new_snapshot
580 .edits_since::<usize>(&old_snapshot.version)
581 .collect();
582
583 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
584
585 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
586 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
587 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
588 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
589
590 const CONTEXT_LINES: u32 = 3;
591
592 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
593 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
594 let old_context_end_row =
595 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
596 let new_context_end_row =
597 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
598
599 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
600 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
601 let old_end_line_offset = old_snapshot
602 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
603 let new_end_line_offset = new_snapshot
604 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
605 let old_edit_range = old_start_line_offset..old_end_line_offset;
606 let new_edit_range = new_start_line_offset..new_end_line_offset;
607
608 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
609 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
610
611 let diff = language::unified_diff_with_offsets(
612 &old_region_text,
613 &new_region_text,
614 old_context_start_row,
615 new_context_start_row,
616 );
617
618 Some((diff, new_start_point..new_end_point))
619}
620
621fn buffer_path_with_id_fallback(
622 file: Option<&Arc<dyn File>>,
623 snapshot: &TextBufferSnapshot,
624 cx: &App,
625) -> Arc<Path> {
626 if let Some(file) = file {
627 file.full_path(cx).into()
628 } else {
629 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
630 }
631}
632
633impl EditPredictionStore {
634 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
635 cx.try_global::<EditPredictionStoreGlobal>()
636 .map(|global| global.0.clone())
637 }
638
639 pub fn global(
640 client: &Arc<Client>,
641 user_store: &Entity<UserStore>,
642 cx: &mut App,
643 ) -> Entity<Self> {
644 cx.try_global::<EditPredictionStoreGlobal>()
645 .map(|global| global.0.clone())
646 .unwrap_or_else(|| {
647 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
648 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
649 ep_store
650 })
651 }
652
653 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
654 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
655 let data_collection_choice = Self::load_data_collection_choice();
656
657 let llm_token = LlmApiToken::default();
658
659 let (reject_tx, reject_rx) = mpsc::unbounded();
660 cx.background_spawn({
661 let client = client.clone();
662 let llm_token = llm_token.clone();
663 let app_version = AppVersion::global(cx);
664 let background_executor = cx.background_executor().clone();
665 async move {
666 Self::handle_rejected_predictions(
667 reject_rx,
668 client,
669 llm_token,
670 app_version,
671 background_executor,
672 )
673 .await
674 }
675 })
676 .detach();
677
678 let this = Self {
679 projects: HashMap::default(),
680 client,
681 user_store,
682 llm_token,
683 _llm_token_subscription: cx.subscribe(
684 &refresh_llm_token_listener,
685 |this, _listener, _event, cx| {
686 let client = this.client.clone();
687 let llm_token = this.llm_token.clone();
688 cx.spawn(async move |_this, _cx| {
689 llm_token.refresh(&client).await?;
690 anyhow::Ok(())
691 })
692 .detach_and_log_err(cx);
693 },
694 ),
695 update_required: false,
696 edit_prediction_model: EditPredictionModel::Zeta2,
697 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
698 sweep_ai: SweepAi::new(cx),
699 mercury: Mercury::new(cx),
700
701 data_collection_choice,
702 reject_predictions_tx: reject_tx,
703 rated_predictions: Default::default(),
704 shown_predictions: Default::default(),
705 };
706
707 this
708 }
709
710 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
711 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
712 let format = ZetaFormat::parse(&version_str).ok()?;
713 let model_id = env::var("ZED_ZETA_MODEL").ok();
714 Some(Zeta2RawConfig { model_id, format })
715 }
716
/// Selects the model used for subsequent edit predictions.
pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
    self.edit_prediction_model = model;
}
720
/// Installs a raw Zeta2 configuration, replacing any previous one
/// (including one read from the environment in `new`).
pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
    self.zeta2_raw_config = Some(config);
}
724
/// Returns the raw Zeta2 configuration, if one is set.
pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
    self.zeta2_raw_config.as_ref()
}
728
729 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
730 use ui::IconName;
731 match self.edit_prediction_model {
732 EditPredictionModel::Sweep => {
733 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
734 .with_disabled(IconName::SweepAiDisabled)
735 .with_up(IconName::SweepAiUp)
736 .with_down(IconName::SweepAiDown)
737 .with_error(IconName::SweepAiError)
738 }
739 EditPredictionModel::Mercury => {
740 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
741 }
742 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
743 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
744 .with_disabled(IconName::ZedPredictDisabled)
745 .with_up(IconName::ZedPredictUp)
746 .with_down(IconName::ZedPredictDown)
747 .with_error(IconName::ZedPredictError)
748 }
749 EditPredictionModel::Fim { .. } => {
750 let settings = &all_language_settings(None, cx).edit_predictions;
751 match settings.provider {
752 EditPredictionProvider::Ollama => {
753 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
754 }
755 _ => {
756 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
757 }
758 }
759 }
760 }
761 }
762
763 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
764 self.sweep_ai.api_token.read(cx).has_key()
765 }
766
767 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
768 self.mercury.api_token.read(cx).has_key()
769 }
770
771 pub fn clear_history(&mut self) {
772 for project_state in self.projects.values_mut() {
773 project_state.events.clear();
774 project_state.last_event.take();
775 }
776 }
777
778 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
779 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
780 project_state.events.clear();
781 project_state.last_event.take();
782 }
783 }
784
785 pub fn edit_history_for_project(
786 &self,
787 project: &Entity<Project>,
788 cx: &App,
789 ) -> Vec<StoredEvent> {
790 self.projects
791 .get(&project.entity_id())
792 .map(|project_state| project_state.events(cx))
793 .unwrap_or_default()
794 }
795
796 pub fn context_for_project<'a>(
797 &'a self,
798 project: &Entity<Project>,
799 cx: &'a mut App,
800 ) -> Vec<RelatedFile> {
801 self.projects
802 .get(&project.entity_id())
803 .map(|project_state| {
804 project_state.context.update(cx, |context, cx| {
805 context
806 .related_files_with_buffers(cx)
807 .map(|(mut related_file, buffer)| {
808 related_file.in_open_source_repo = buffer
809 .read(cx)
810 .file()
811 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
812 related_file
813 })
814 .collect()
815 })
816 })
817 .unwrap_or_default()
818 }
819
820 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
821 self.projects
822 .get(&project.entity_id())
823 .and_then(|project| project.copilot.clone())
824 }
825
826 pub fn start_copilot_for_project(
827 &mut self,
828 project: &Entity<Project>,
829 cx: &mut Context<Self>,
830 ) -> Option<Entity<Copilot>> {
831 if DisableAiSettings::get(None, cx).disable_ai {
832 return None;
833 }
834 let state = self.get_or_init_project(project, cx);
835
836 if state.copilot.is_some() {
837 return state.copilot.clone();
838 }
839 let _project = project.clone();
840 let project = project.read(cx);
841
842 let node = project.node_runtime().cloned();
843 if let Some(node) = node {
844 let next_id = project.languages().next_language_server_id();
845 let fs = project.fs().clone();
846
847 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
848 state.copilot = Some(copilot.clone());
849 Some(copilot)
850 } else {
851 None
852 }
853 }
854
855 pub fn context_for_project_with_buffers<'a>(
856 &'a self,
857 project: &Entity<Project>,
858 cx: &'a mut App,
859 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
860 self.projects
861 .get(&project.entity_id())
862 .map(|project| {
863 project.context.update(cx, |context, cx| {
864 context.related_files_with_buffers(cx).collect()
865 })
866 })
867 .unwrap_or_default()
868 }
869
870 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
871 if matches!(
872 self.edit_prediction_model,
873 EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
874 ) {
875 self.user_store.read(cx).edit_prediction_usage()
876 } else {
877 None
878 }
879 }
880
/// Ensures per-project state exists for `project`, creating it if necessary.
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
    self.get_or_init_project(project, cx);
}
884
885 pub fn register_buffer(
886 &mut self,
887 buffer: &Entity<Buffer>,
888 project: &Entity<Project>,
889 cx: &mut Context<Self>,
890 ) {
891 let project_state = self.get_or_init_project(project, cx);
892 Self::register_buffer_impl(project_state, buffer, project, cx);
893 }
894
/// Returns the state for `project`, creating it on first access.
///
/// A newly created state gets its own `RelatedExcerptStore` (whose events
/// are forwarded to `handle_excerpt_store_event`), a subscription to project
/// events, and a release observer that removes the state when the project
/// entity is released.
fn get_or_init_project(
    &mut self,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &mut ProjectState {
    let entity_id = project.entity_id();
    self.projects
        .entry(entity_id)
        .or_insert_with(|| ProjectState {
            context: {
                let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                // The subscription is detached rather than stored in `_subscriptions`.
                cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                    this.handle_excerpt_store_event(entity_id, event);
                })
                .detach();
                related_excerpt_store
            },
            events: VecDeque::new(),
            last_event: None,
            recent_paths: VecDeque::new(),
            debug_tx: None,
            registered_buffers: HashMap::default(),
            current_prediction: None,
            cancelled_predictions: HashSet::default(),
            pending_predictions: ArrayVec::new(),
            next_pending_prediction_id: 0,
            last_edit_prediction_refresh: None,
            last_jump_prediction_refresh: None,
            license_detection_watchers: HashMap::default(),
            user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
            _subscriptions: [
                cx.subscribe(&project, Self::handle_project_event),
                // Drop this project's state once the project entity is released.
                cx.observe_release(&project, move |this, _, cx| {
                    this.projects.remove(&entity_id);
                    cx.notify();
                }),
            ],
            copilot: None,
        })
}
935
/// Drops all state held for `project`; a no-op for unknown projects.
pub fn remove_project(&mut self, project: &Entity<Project>) {
    self.projects.remove(&project.entity_id());
}
939
940 fn handle_excerpt_store_event(
941 &mut self,
942 project_entity_id: EntityId,
943 event: &RelatedExcerptStoreEvent,
944 ) {
945 if let Some(project_state) = self.projects.get(&project_entity_id) {
946 if let Some(debug_tx) = project_state.debug_tx.clone() {
947 match event {
948 RelatedExcerptStoreEvent::StartedRefresh => {
949 debug_tx
950 .unbounded_send(DebugEvent::ContextRetrievalStarted(
951 ContextRetrievalStartedDebugEvent {
952 project_entity_id: project_entity_id,
953 timestamp: Instant::now(),
954 search_prompt: String::new(),
955 },
956 ))
957 .ok();
958 }
959 RelatedExcerptStoreEvent::FinishedRefresh {
960 cache_hit_count,
961 cache_miss_count,
962 mean_definition_latency,
963 max_definition_latency,
964 } => {
965 debug_tx
966 .unbounded_send(DebugEvent::ContextRetrievalFinished(
967 ContextRetrievalFinishedDebugEvent {
968 project_entity_id: project_entity_id,
969 timestamp: Instant::now(),
970 metadata: vec![
971 (
972 "Cache Hits",
973 format!(
974 "{}/{}",
975 cache_hit_count,
976 cache_hit_count + cache_miss_count
977 )
978 .into(),
979 ),
980 (
981 "Max LSP Time",
982 format!("{} ms", max_definition_latency.as_millis())
983 .into(),
984 ),
985 (
986 "Mean LSP Time",
987 format!("{} ms", mean_definition_latency.as_millis())
988 .into(),
989 ),
990 ],
991 },
992 ))
993 .ok();
994 }
995 }
996 }
997 }
998 }
999
1000 pub fn debug_info(
1001 &mut self,
1002 project: &Entity<Project>,
1003 cx: &mut Context<Self>,
1004 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1005 let project_state = self.get_or_init_project(project, cx);
1006 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1007 project_state.debug_tx = Some(debug_watch_tx);
1008 debug_watch_rx
1009 }
1010
1011 fn handle_project_event(
1012 &mut self,
1013 project: Entity<Project>,
1014 event: &project::Event,
1015 cx: &mut Context<Self>,
1016 ) {
1017 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1018 return;
1019 }
1020 // TODO [zeta2] init with recent paths
1021 match event {
1022 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1023 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1024 return;
1025 };
1026 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1027 if let Some(path) = path {
1028 if let Some(ix) = project_state
1029 .recent_paths
1030 .iter()
1031 .position(|probe| probe == &path)
1032 {
1033 project_state.recent_paths.remove(ix);
1034 }
1035 project_state.recent_paths.push_front(path);
1036 }
1037 }
1038 project::Event::DiagnosticsUpdated { .. } => {
1039 if cx.has_flag::<EditPredictionJumpsFeatureFlag>() {
1040 self.refresh_prediction_from_diagnostics(
1041 project,
1042 DiagnosticSearchScope::Global,
1043 cx,
1044 );
1045 }
1046 }
1047 _ => (),
1048 }
1049 }
1050
/// Registers `buffer` in `project_state` and returns its entry.
///
/// If the buffer has a file whose worktree is found in the project, a
/// license detection watcher is created for that worktree when none exists
/// yet. An already-registered buffer is returned as-is; otherwise a new
/// entry is created with the current text snapshot and file, an `Edited`
/// subscription, and a release observer that unregisters the buffer.
fn register_buffer_impl<'a>(
    project_state: &'a mut ProjectState,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &'a mut RegisteredBuffer {
    let buffer_id = buffer.entity_id();

    if let Some(file) = buffer.read(cx).file() {
        let worktree_id = file.worktree_id(cx);
        if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
            project_state
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| {
                    let project_entity_id = project.entity_id();
                    // Remove the watcher again once the worktree is released.
                    cx.observe_release(&worktree, move |this, _worktree, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state
                            .license_detection_watchers
                            .remove(&worktree_id);
                    })
                    .detach();
                    Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                });
        }
    }

    match project_state.registered_buffers.entry(buffer_id) {
        hash_map::Entry::Occupied(entry) => entry.into_mut(),
        hash_map::Entry::Vacant(entry) => {
            let buf = buffer.read(cx);
            let snapshot = buf.text_snapshot();
            let file = buf.file().cloned();
            let project_entity_id = project.entity_id();
            entry.insert(RegisteredBuffer {
                snapshot,
                file,
                last_position: None,
                _subscriptions: [
                    cx.subscribe(buffer, {
                        // Hold the project weakly; edits are ignored once it is gone.
                        let project = project.downgrade();
                        move |this, buffer, event, cx| {
                            if let language::BufferEvent::Edited = event
                                && let Some(project) = project.upgrade()
                            {
                                this.report_changes_for_buffer(&buffer, &project, false, cx);
                            }
                        }
                    }),
                    // Unregister the buffer when its entity is released.
                    cx.observe_release(buffer, move |this, _buffer, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state.registered_buffers.remove(&buffer_id);
                    }),
                ],
            })
        }
    }
}
1116
    /// Records the edits made to `buffer` since its last registered snapshot.
    ///
    /// The buffer is registered on demand. If its version is unchanged, or no
    /// edits are found, nothing is recorded. Otherwise this:
    /// 1. records a `UserActionRecord` classifying the change,
    /// 2. either folds the change into the project's `last_event` (same buffer,
    ///    same `is_predicted` source, and close enough in lines), or
    /// 3. finalizes the previous `last_event` into `events` and starts a new
    ///    `LastEvent` for this change.
    ///
    /// `is_predicted` marks whether the change came from an accepted prediction.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot we already hold.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Store the new file/snapshot in the registration and keep the old ones
        // for diffing below.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Tally the edits and build one anchor range spanning from the start of
        // the first edit to the end of the last one.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // No edits between the two versions.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Classify the change from the totals: a "char" action is one where the
        // total inserted (or deleted) length equals the number of edits;
        // anything larger is treated as a selection-sized action. Changes that
        // both delete and insert are classified by their inserted length.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Falls back to 0 if the system clock is before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The change continues the last event only if it starts exactly
            // where that event's newest snapshot left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If enough time passed since the previous edit, remember the
                // snapshot from before this edit as the state after the last
                // editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the previous event into the history, evicting
        // the oldest stored event when the history is at its size limit.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1244
1245 fn prediction_at(
1246 &mut self,
1247 buffer: &Entity<Buffer>,
1248 position: Option<language::Anchor>,
1249 project: &Entity<Project>,
1250 cx: &App,
1251 ) -> Option<BufferEditPrediction<'_>> {
1252 let project_state = self.projects.get_mut(&project.entity_id())?;
1253 if let Some(position) = position
1254 && let Some(buffer) = project_state
1255 .registered_buffers
1256 .get_mut(&buffer.entity_id())
1257 {
1258 buffer.last_position = Some(position);
1259 }
1260
1261 let CurrentEditPrediction {
1262 requested_by,
1263 prediction,
1264 ..
1265 } = project_state.current_prediction.as_ref()?;
1266
1267 if prediction.targets_buffer(buffer.read(cx)) {
1268 Some(BufferEditPrediction::Local { prediction })
1269 } else {
1270 let show_jump = match requested_by {
1271 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1272 requested_by_buffer_id == &buffer.entity_id()
1273 }
1274 PredictionRequestedBy::DiagnosticsUpdate => true,
1275 };
1276
1277 if show_jump {
1278 Some(BufferEditPrediction::Jump { prediction })
1279 } else {
1280 None
1281 }
1282 }
1283 }
1284
1285 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1286 let Some(current_prediction) = self
1287 .projects
1288 .get_mut(&project.entity_id())
1289 .and_then(|project_state| project_state.current_prediction.take())
1290 else {
1291 return;
1292 };
1293
1294 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1295
1296 // can't hold &mut project_state ref across report_changes_for_buffer_call
1297 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1298 return;
1299 };
1300
1301 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1302 project_state.cancel_pending_prediction(pending_prediction, cx);
1303 }
1304
1305 match self.edit_prediction_model {
1306 EditPredictionModel::Sweep => {
1307 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1308 }
1309 EditPredictionModel::Mercury => {
1310 mercury::edit_prediction_accepted(
1311 current_prediction.prediction.id,
1312 self.client.http_client(),
1313 cx,
1314 );
1315 }
1316 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1317 let is_cloud = !matches!(
1318 all_language_settings(None, cx).edit_predictions.provider,
1319 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1320 );
1321 if is_cloud {
1322 zeta::edit_prediction_accepted(self, current_prediction, cx)
1323 }
1324 }
1325 EditPredictionModel::Fim { .. } => {}
1326 }
1327 }
1328
1329 async fn handle_rejected_predictions(
1330 rx: UnboundedReceiver<EditPredictionRejection>,
1331 client: Arc<Client>,
1332 llm_token: LlmApiToken,
1333 app_version: Version,
1334 background_executor: BackgroundExecutor,
1335 ) {
1336 let mut rx = std::pin::pin!(rx.peekable());
1337 let mut batched = Vec::new();
1338
1339 while let Some(rejection) = rx.next().await {
1340 batched.push(rejection);
1341
1342 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1343 select_biased! {
1344 next = rx.as_mut().peek().fuse() => {
1345 if next.is_some() {
1346 continue;
1347 }
1348 }
1349 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1350 }
1351 }
1352
1353 let url = client
1354 .http_client()
1355 .build_zed_llm_url("/predict_edits/reject", &[])
1356 .unwrap();
1357
1358 let flush_count = batched
1359 .len()
1360 // in case items have accumulated after failure
1361 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1362 let start = batched.len() - flush_count;
1363
1364 let body = RejectEditPredictionsBodyRef {
1365 rejections: &batched[start..],
1366 };
1367
1368 let result = Self::send_api_request::<()>(
1369 |builder| {
1370 let req = builder
1371 .uri(url.as_ref())
1372 .body(serde_json::to_string(&body)?.into());
1373 anyhow::Ok(req?)
1374 },
1375 client.clone(),
1376 llm_token.clone(),
1377 app_version.clone(),
1378 true,
1379 )
1380 .await;
1381
1382 if result.log_err().is_some() {
1383 batched.drain(start..);
1384 }
1385 }
1386 }
1387
1388 fn reject_current_prediction(
1389 &mut self,
1390 reason: EditPredictionRejectReason,
1391 project: &Entity<Project>,
1392 cx: &App,
1393 ) {
1394 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1395 project_state.pending_predictions.clear();
1396 if let Some(prediction) = project_state.current_prediction.take() {
1397 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1398 }
1399 };
1400 }
1401
1402 fn did_show_current_prediction(
1403 &mut self,
1404 project: &Entity<Project>,
1405 display_type: edit_prediction_types::SuggestionDisplayType,
1406 cx: &mut Context<Self>,
1407 ) {
1408 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1409 return;
1410 };
1411
1412 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1413 return;
1414 };
1415
1416 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1417 let previous_shown_with = current_prediction.shown_with;
1418
1419 if previous_shown_with.is_none() || !is_jump {
1420 current_prediction.shown_with = Some(display_type);
1421 }
1422
1423 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1424
1425 if is_first_non_jump_show {
1426 current_prediction.was_shown = true;
1427 }
1428
1429 let display_type_changed = previous_shown_with != Some(display_type);
1430
1431 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1432 sweep_ai::edit_prediction_shown(
1433 &self.sweep_ai,
1434 self.client.clone(),
1435 ¤t_prediction.prediction,
1436 display_type,
1437 cx,
1438 );
1439 }
1440
1441 if is_first_non_jump_show {
1442 self.shown_predictions
1443 .push_front(current_prediction.prediction.clone());
1444 if self.shown_predictions.len() > 50 {
1445 let completion = self.shown_predictions.pop_back().unwrap();
1446 self.rated_predictions.remove(&completion.id);
1447 }
1448 }
1449 }
1450
1451 fn reject_prediction(
1452 &mut self,
1453 prediction_id: EditPredictionId,
1454 reason: EditPredictionRejectReason,
1455 was_shown: bool,
1456 cx: &App,
1457 ) {
1458 match self.edit_prediction_model {
1459 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1460 let is_cloud = !matches!(
1461 all_language_settings(None, cx).edit_predictions.provider,
1462 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1463 );
1464 if is_cloud {
1465 self.reject_predictions_tx
1466 .unbounded_send(EditPredictionRejection {
1467 request_id: prediction_id.to_string(),
1468 reason,
1469 was_shown,
1470 })
1471 .log_err();
1472 }
1473 }
1474 EditPredictionModel::Mercury => {
1475 mercury::edit_prediction_rejected(
1476 prediction_id,
1477 was_shown,
1478 reason,
1479 self.client.http_client(),
1480 cx,
1481 );
1482 }
1483 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1484 }
1485 }
1486
1487 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1488 self.projects
1489 .get(&project.entity_id())
1490 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1491 }
1492
1493 pub fn refresh_prediction_from_buffer(
1494 &mut self,
1495 project: Entity<Project>,
1496 buffer: Entity<Buffer>,
1497 position: language::Anchor,
1498 cx: &mut Context<Self>,
1499 ) {
1500 self.queue_prediction_refresh(
1501 project.clone(),
1502 PredictEditsRequestTrigger::Other,
1503 buffer.entity_id(),
1504 cx,
1505 move |this, cx| {
1506 let Some(request_task) = this
1507 .update(cx, |this, cx| {
1508 this.request_prediction(
1509 &project,
1510 &buffer,
1511 position,
1512 PredictEditsRequestTrigger::Other,
1513 cx,
1514 )
1515 })
1516 .log_err()
1517 else {
1518 return Task::ready(anyhow::Ok(None));
1519 };
1520
1521 cx.spawn(async move |_cx| {
1522 request_task.await.map(|prediction_result| {
1523 prediction_result.map(|prediction_result| {
1524 (
1525 prediction_result,
1526 PredictionRequestedBy::Buffer(buffer.entity_id()),
1527 )
1528 })
1529 })
1530 })
1531 },
1532 )
1533 }
1534
1535 pub fn refresh_prediction_from_diagnostics(
1536 &mut self,
1537 project: Entity<Project>,
1538 scope: DiagnosticSearchScope,
1539 cx: &mut Context<Self>,
1540 ) {
1541 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1542 return;
1543 }
1544
1545 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1546 return;
1547 };
1548
1549 // Prefer predictions from buffer
1550 if project_state.current_prediction.is_some() {
1551 return;
1552 }
1553
1554 self.queue_prediction_refresh(
1555 project.clone(),
1556 PredictEditsRequestTrigger::Diagnostics,
1557 project.entity_id(),
1558 cx,
1559 move |this, cx| {
1560 let Some((active_buffer, snapshot, cursor_point)) = this
1561 .read_with(cx, |this, cx| {
1562 let project_state = this.projects.get(&project.entity_id())?;
1563 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1564 let snapshot = buffer.read(cx).snapshot();
1565
1566 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1567 return None;
1568 }
1569
1570 let cursor_point = position
1571 .map(|pos| pos.to_point(&snapshot))
1572 .unwrap_or_default();
1573
1574 Some((buffer, snapshot, cursor_point))
1575 })
1576 .log_err()
1577 .flatten()
1578 else {
1579 return Task::ready(anyhow::Ok(None));
1580 };
1581
1582 cx.spawn(async move |cx| {
1583 let diagnostic_search_range = match scope {
1584 DiagnosticSearchScope::Local => {
1585 let diagnostic_search_start =
1586 cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1587 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1588 Point::new(diagnostic_search_start, 0)
1589 ..Point::new(diagnostic_search_end, 0)
1590 }
1591 DiagnosticSearchScope::Global => Default::default(),
1592 };
1593
1594 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1595 active_buffer,
1596 &snapshot,
1597 diagnostic_search_range,
1598 cursor_point,
1599 &project,
1600 cx,
1601 )
1602 .await?
1603 else {
1604 return anyhow::Ok(None);
1605 };
1606
1607 let Some(prediction_result) = this
1608 .update(cx, |this, cx| {
1609 this.request_prediction(
1610 &project,
1611 &jump_buffer,
1612 jump_position,
1613 PredictEditsRequestTrigger::Diagnostics,
1614 cx,
1615 )
1616 })?
1617 .await?
1618 else {
1619 return anyhow::Ok(None);
1620 };
1621
1622 this.update(cx, |this, cx| {
1623 Some((
1624 if this
1625 .get_or_init_project(&project, cx)
1626 .current_prediction
1627 .is_none()
1628 {
1629 prediction_result
1630 } else {
1631 EditPredictionResult {
1632 id: prediction_result.id,
1633 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1634 }
1635 },
1636 PredictionRequestedBy::DiagnosticsUpdate,
1637 ))
1638 })
1639 })
1640 },
1641 );
1642 }
1643
1644 fn predictions_enabled_at(
1645 snapshot: &BufferSnapshot,
1646 position: Option<language::Anchor>,
1647 cx: &App,
1648 ) -> bool {
1649 let file = snapshot.file();
1650 let all_settings = all_language_settings(file, cx);
1651 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1652 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1653 {
1654 return false;
1655 }
1656
1657 if let Some(last_position) = position {
1658 let settings = snapshot.settings_at(last_position, cx);
1659
1660 if !settings.edit_predictions_disabled_in.is_empty()
1661 && let Some(scope) = snapshot.language_scope_at(last_position)
1662 && let Some(scope_name) = scope.override_name()
1663 && settings
1664 .edit_predictions_disabled_in
1665 .iter()
1666 .any(|s| s == scope_name)
1667 {
1668 return false;
1669 }
1670 }
1671
1672 true
1673 }
1674
    /// Minimum spacing between prediction refreshes queued for the same
    /// throttle entity: `queue_prediction_refresh` waits out the remainder of
    /// this interval before starting a new request.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1676}
1677
1678fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1679 match provider {
1680 EditPredictionProvider::Zed
1681 | EditPredictionProvider::Sweep
1682 | EditPredictionProvider::Mercury
1683 | EditPredictionProvider::Ollama
1684 | EditPredictionProvider::OpenAiCompatibleApi
1685 | EditPredictionProvider::Experimental(_) => true,
1686 EditPredictionProvider::None
1687 | EditPredictionProvider::Copilot
1688 | EditPredictionProvider::Supermaven
1689 | EditPredictionProvider::Codestral => false,
1690 }
1691}
1692
1693impl EditPredictionStore {
1694 fn queue_prediction_refresh(
1695 &mut self,
1696 project: Entity<Project>,
1697 request_trigger: PredictEditsRequestTrigger,
1698 throttle_entity: EntityId,
1699 cx: &mut Context<Self>,
1700 do_refresh: impl FnOnce(
1701 WeakEntity<Self>,
1702 &mut AsyncApp,
1703 )
1704 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1705 + 'static,
1706 ) {
1707 fn select_throttle(
1708 project_state: &mut ProjectState,
1709 request_trigger: PredictEditsRequestTrigger,
1710 ) -> &mut Option<(EntityId, Instant)> {
1711 match request_trigger {
1712 PredictEditsRequestTrigger::Diagnostics => {
1713 &mut project_state.last_jump_prediction_refresh
1714 }
1715 _ => &mut project_state.last_edit_prediction_refresh,
1716 }
1717 }
1718
1719 let (needs_acceptance_tracking, max_pending_predictions) =
1720 match all_language_settings(None, cx).edit_predictions.provider {
1721 EditPredictionProvider::Zed
1722 | EditPredictionProvider::Sweep
1723 | EditPredictionProvider::Mercury
1724 | EditPredictionProvider::Experimental(_) => (true, 2),
1725 EditPredictionProvider::Ollama => (false, 1),
1726 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1727 EditPredictionProvider::None
1728 | EditPredictionProvider::Copilot
1729 | EditPredictionProvider::Supermaven
1730 | EditPredictionProvider::Codestral => {
1731 log::error!("queue_prediction_refresh called with non-store provider");
1732 return;
1733 }
1734 };
1735
1736 let drop_on_cancel = !needs_acceptance_tracking;
1737 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1738 let project_state = self.get_or_init_project(&project, cx);
1739 let pending_prediction_id = project_state.next_pending_prediction_id;
1740 project_state.next_pending_prediction_id += 1;
1741 let last_request = *select_throttle(project_state, request_trigger);
1742
1743 let task = cx.spawn(async move |this, cx| {
1744 if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1745 if throttle_entity != last_entity {
1746 return None;
1747 }
1748 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1749 }) {
1750 cx.background_executor().timer(timeout).await;
1751 }
1752
1753 // If this task was cancelled before the throttle timeout expired,
1754 // do not perform a request.
1755 let mut is_cancelled = true;
1756 this.update(cx, |this, cx| {
1757 let project_state = this.get_or_init_project(&project, cx);
1758 let was_cancelled = project_state
1759 .cancelled_predictions
1760 .remove(&pending_prediction_id);
1761 if !was_cancelled {
1762 let new_refresh = (throttle_entity, Instant::now());
1763 *select_throttle(project_state, request_trigger) = Some(new_refresh);
1764 is_cancelled = false;
1765 }
1766 })
1767 .ok();
1768 if is_cancelled {
1769 return None;
1770 }
1771
1772 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1773 let new_prediction_id = new_prediction_result
1774 .as_ref()
1775 .map(|(prediction, _)| prediction.id.clone());
1776
1777 // When a prediction completes, remove it from the pending list, and cancel
1778 // any pending predictions that were enqueued before it.
1779 this.update(cx, |this, cx| {
1780 let project_state = this.get_or_init_project(&project, cx);
1781
1782 let is_cancelled = project_state
1783 .cancelled_predictions
1784 .remove(&pending_prediction_id);
1785
1786 let new_current_prediction = if !is_cancelled
1787 && let Some((prediction_result, requested_by)) = new_prediction_result
1788 {
1789 match prediction_result.prediction {
1790 Ok(prediction) => {
1791 let new_prediction = CurrentEditPrediction {
1792 requested_by,
1793 prediction,
1794 was_shown: false,
1795 shown_with: None,
1796 };
1797
1798 if let Some(current_prediction) =
1799 project_state.current_prediction.as_ref()
1800 {
1801 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1802 {
1803 this.reject_current_prediction(
1804 EditPredictionRejectReason::Replaced,
1805 &project,
1806 cx,
1807 );
1808
1809 Some(new_prediction)
1810 } else {
1811 this.reject_prediction(
1812 new_prediction.prediction.id,
1813 EditPredictionRejectReason::CurrentPreferred,
1814 false,
1815 cx,
1816 );
1817 None
1818 }
1819 } else {
1820 Some(new_prediction)
1821 }
1822 }
1823 Err(reject_reason) => {
1824 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1825 None
1826 }
1827 }
1828 } else {
1829 None
1830 };
1831
1832 let project_state = this.get_or_init_project(&project, cx);
1833
1834 if let Some(new_prediction) = new_current_prediction {
1835 project_state.current_prediction = Some(new_prediction);
1836 }
1837
1838 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1839 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1840 if pending_prediction.id == pending_prediction_id {
1841 pending_predictions.remove(ix);
1842 for pending_prediction in pending_predictions.drain(0..ix) {
1843 project_state.cancel_pending_prediction(pending_prediction, cx)
1844 }
1845 break;
1846 }
1847 }
1848 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1849 cx.notify();
1850 })
1851 .ok();
1852
1853 new_prediction_id
1854 });
1855
1856 if project_state.pending_predictions.len() < max_pending_predictions {
1857 project_state.pending_predictions.push(PendingPrediction {
1858 id: pending_prediction_id,
1859 task,
1860 drop_on_cancel,
1861 });
1862 } else {
1863 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1864 project_state.pending_predictions.push(PendingPrediction {
1865 id: pending_prediction_id,
1866 task,
1867 drop_on_cancel,
1868 });
1869 project_state.cancel_pending_prediction(pending_prediction, cx);
1870 }
1871 }
1872
1873 pub fn request_prediction(
1874 &mut self,
1875 project: &Entity<Project>,
1876 active_buffer: &Entity<Buffer>,
1877 position: language::Anchor,
1878 trigger: PredictEditsRequestTrigger,
1879 cx: &mut Context<Self>,
1880 ) -> Task<Result<Option<EditPredictionResult>>> {
1881 self.request_prediction_internal(
1882 project.clone(),
1883 active_buffer.clone(),
1884 position,
1885 trigger,
1886 cx.has_flag::<Zeta2FeatureFlag>(),
1887 cx,
1888 )
1889 }
1890
1891 fn request_prediction_internal(
1892 &mut self,
1893 project: Entity<Project>,
1894 active_buffer: Entity<Buffer>,
1895 position: language::Anchor,
1896 trigger: PredictEditsRequestTrigger,
1897 allow_jump: bool,
1898 cx: &mut Context<Self>,
1899 ) -> Task<Result<Option<EditPredictionResult>>> {
1900 self.get_or_init_project(&project, cx);
1901 let project_state = self.projects.get(&project.entity_id()).unwrap();
1902 let stored_events = project_state.events(cx);
1903 let has_events = !stored_events.is_empty();
1904 let events: Vec<Arc<zeta_prompt::Event>> =
1905 stored_events.into_iter().map(|e| e.event).collect();
1906 let debug_tx = project_state.debug_tx.clone();
1907
1908 let snapshot = active_buffer.read(cx).snapshot();
1909 let cursor_point = position.to_point(&snapshot);
1910 let current_offset = position.to_offset(&snapshot);
1911
1912 let mut user_actions: Vec<UserActionRecord> =
1913 project_state.user_actions.iter().cloned().collect();
1914
1915 if let Some(last_action) = user_actions.last() {
1916 if last_action.buffer_id == active_buffer.entity_id()
1917 && current_offset != last_action.offset
1918 {
1919 let timestamp_epoch_ms = SystemTime::now()
1920 .duration_since(UNIX_EPOCH)
1921 .map(|d| d.as_millis() as u64)
1922 .unwrap_or(0);
1923 user_actions.push(UserActionRecord {
1924 action_type: UserActionType::CursorMovement,
1925 buffer_id: active_buffer.entity_id(),
1926 line_number: cursor_point.row,
1927 offset: current_offset,
1928 timestamp_epoch_ms,
1929 });
1930 }
1931 }
1932 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1933 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1934 let diagnostic_search_range =
1935 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1936
1937 let related_files = self.context_for_project(&project, cx);
1938
1939 let inputs = EditPredictionModelInput {
1940 project: project.clone(),
1941 buffer: active_buffer,
1942 snapshot: snapshot,
1943 position,
1944 events,
1945 related_files,
1946 recent_paths: project_state.recent_paths.clone(),
1947 trigger,
1948 diagnostic_search_range: diagnostic_search_range,
1949 debug_tx,
1950 user_actions,
1951 };
1952
1953 let task = match self.edit_prediction_model {
1954 EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
1955 self,
1956 inputs,
1957 Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1958 cx,
1959 ),
1960 EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
1961 self,
1962 inputs,
1963 Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1964 cx,
1965 ),
1966 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
1967 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1968 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1969 };
1970
1971 cx.spawn(async move |this, cx| {
1972 let prediction = task.await?;
1973
1974 if prediction.is_none() && allow_jump && has_events {
1975 this.update(cx, |this, cx| {
1976 this.refresh_prediction_from_diagnostics(
1977 project,
1978 DiagnosticSearchScope::Local,
1979 cx,
1980 );
1981 })?;
1982 return anyhow::Ok(None);
1983 }
1984
1985 Ok(prediction)
1986 })
1987 }
1988
    /// Finds a diagnostic location to jump to, preferring the active buffer.
    ///
    /// In the active buffer, the primary entry of each diagnostic group is
    /// considered. Entries overlapping `active_buffer_diagnostic_search_range`
    /// are skipped, as are entries that start within `DIAGNOSTIC_LINES_RANGE`
    /// rows of a selection returned by `selections_in_range` unless they are
    /// equally close to the local cursor. Among the rest, the one whose start
    /// row is closest to `active_buffer_cursor_point` wins.
    ///
    /// If the active buffer yields nothing, other paths with diagnostic
    /// summaries are tried, ordered by how many leading path components they
    /// share with the active buffer's path. Each candidate buffer is opened;
    /// buffers with any selection from `selections_in_range` are skipped, and
    /// the first remaining buffer with a diagnostic provides the location (the
    /// start of its entry with the minimum severity value).
    ///
    /// Returns `Ok(None)` when no location is found, and an error if opening a
    /// candidate buffer fails.
    pub(crate) async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Rows of the selection heads reported for the active buffer
        // (presumably collaborators' cursors, going by the name; confirm
        // against `selections_in_range`).
        let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
            .flat_map(|(_, _, _, selections)| {
                selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
            })
            .collect();

        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Diagnostics inside the search range are not jump targets.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    return None;
                }
                let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
                    range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
                });
                let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
                    <= DIAGNOSTIC_LINES_RANGE;
                // Leave diagnostics next to another cursor alone unless the
                // local cursor is just as close.
                if near_collaborator && !near_local {
                    return None;
                }
                Some(range.start)
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            // Every other path with diagnostics, paired with the number of
            // leading path components it shares with the active buffer's path.
            let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
                project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .map(|(path, _, _)| {
                        let shared_prefix = path
                            .path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count();
                        (path, shared_prefix)
                    })
                    .collect()
            });

            // Longest shared prefix first.
            candidates.sort_by(|a, b| b.1.cmp(&a.1));

            for (path, _) in candidates {
                let candidate_buffer = project
                    .update(cx, |project, cx| project.open_buffer(path, cx))
                    .await?;

                let (has_collaborators, diagnostic_position) =
                    candidate_buffer.read_with(cx, |buffer, _cx| {
                        let snapshot = buffer.snapshot();
                        let has_collaborators = snapshot
                            .selections_in_range(Anchor::MIN..Anchor::MAX, false)
                            .next()
                            .is_some();
                        let position = buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start);
                        (has_collaborators, position)
                    });

                if has_collaborators {
                    continue;
                }

                if let Some(position) = diagnostic_position {
                    jump_location = Some((candidate_buffer, position));
                    break;
                }
            }
        }

        anyhow::Ok(jump_location)
    }
2098
2099 async fn send_raw_llm_request(
2100 request: RawCompletionRequest,
2101 client: Arc<Client>,
2102 custom_url: Option<Arc<Url>>,
2103 llm_token: LlmApiToken,
2104 app_version: Version,
2105 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2106 let url = if let Some(custom_url) = custom_url {
2107 custom_url.as_ref().clone()
2108 } else {
2109 client
2110 .http_client()
2111 .build_zed_llm_url("/predict_edits/raw", &[])?
2112 };
2113
2114 Self::send_api_request(
2115 |builder| {
2116 let req = builder
2117 .uri(url.as_ref())
2118 .body(serde_json::to_string(&request)?.into());
2119 Ok(req?)
2120 },
2121 client,
2122 llm_token,
2123 app_version,
2124 true,
2125 )
2126 .await
2127 }
2128
2129 pub(crate) async fn send_v3_request(
2130 input: ZetaPromptInput,
2131 client: Arc<Client>,
2132 llm_token: LlmApiToken,
2133 app_version: Version,
2134 trigger: PredictEditsRequestTrigger,
2135 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2136 let url = client
2137 .http_client()
2138 .build_zed_llm_url("/predict_edits/v3", &[])?;
2139
2140 let request = PredictEditsV3Request { input, trigger };
2141
2142 let json_bytes = serde_json::to_vec(&request)?;
2143 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2144
2145 Self::send_api_request(
2146 |builder| {
2147 let req = builder
2148 .uri(url.as_ref())
2149 .header("Content-Encoding", "zstd")
2150 .body(compressed.clone().into());
2151 Ok(req?)
2152 },
2153 client,
2154 llm_token,
2155 app_version,
2156 true,
2157 )
2158 .await
2159 }
2160
2161 fn handle_api_response<T>(
2162 this: &WeakEntity<Self>,
2163 response: Result<(T, Option<EditPredictionUsage>)>,
2164 cx: &mut gpui::AsyncApp,
2165 ) -> Result<T> {
2166 match response {
2167 Ok((data, usage)) => {
2168 if let Some(usage) = usage {
2169 this.update(cx, |this, cx| {
2170 this.user_store.update(cx, |user_store, cx| {
2171 user_store.update_edit_prediction_usage(usage, cx);
2172 });
2173 })
2174 .ok();
2175 }
2176 Ok(data)
2177 }
2178 Err(err) => {
2179 if err.is::<ZedUpdateRequiredError>() {
2180 cx.update(|cx| {
2181 this.update(cx, |this, _cx| {
2182 this.update_required = true;
2183 })
2184 .ok();
2185
2186 let error_message: SharedString = err.to_string().into();
2187 show_app_notification(
2188 NotificationId::unique::<ZedUpdateRequiredError>(),
2189 cx,
2190 move |cx| {
2191 cx.new(|cx| {
2192 ErrorMessagePrompt::new(error_message.clone(), cx)
2193 .with_link_button("Update Zed", "https://zed.dev/releases")
2194 })
2195 },
2196 );
2197 });
2198 }
2199 Err(err)
2200 }
2201 }
2202 }
2203
2204 async fn send_api_request<Res>(
2205 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2206 client: Arc<Client>,
2207 llm_token: LlmApiToken,
2208 app_version: Version,
2209 require_auth: bool,
2210 ) -> Result<(Res, Option<EditPredictionUsage>)>
2211 where
2212 Res: DeserializeOwned,
2213 {
2214 let http_client = client.http_client();
2215
2216 let mut token = if require_auth {
2217 Some(llm_token.acquire(&client).await?)
2218 } else {
2219 llm_token.acquire(&client).await.ok()
2220 };
2221 let mut did_retry = false;
2222
2223 loop {
2224 let request_builder = http_client::Request::builder().method(Method::POST);
2225
2226 let mut request_builder = request_builder
2227 .header("Content-Type", "application/json")
2228 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2229
2230 // Only add Authorization header if we have a token
2231 if let Some(ref token_value) = token {
2232 request_builder =
2233 request_builder.header("Authorization", format!("Bearer {}", token_value));
2234 }
2235
2236 let request = build(request_builder)?;
2237
2238 let mut response = http_client.send(request).await?;
2239
2240 if let Some(minimum_required_version) = response
2241 .headers()
2242 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2243 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2244 {
2245 anyhow::ensure!(
2246 app_version >= minimum_required_version,
2247 ZedUpdateRequiredError {
2248 minimum_version: minimum_required_version
2249 }
2250 );
2251 }
2252
2253 if response.status().is_success() {
2254 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2255
2256 let mut body = Vec::new();
2257 response.body_mut().read_to_end(&mut body).await?;
2258 return Ok((serde_json::from_slice(&body)?, usage));
2259 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2260 did_retry = true;
2261 token = Some(llm_token.refresh(&client).await?);
2262 } else {
2263 let mut body = String::new();
2264 response.body_mut().read_to_string(&mut body).await?;
2265 anyhow::bail!(
2266 "Request failed with status: {:?}\nBody: {}",
2267 response.status(),
2268 body
2269 );
2270 }
2271 }
2272 }
2273
2274 pub fn refresh_context(
2275 &mut self,
2276 project: &Entity<Project>,
2277 buffer: &Entity<language::Buffer>,
2278 cursor_position: language::Anchor,
2279 cx: &mut Context<Self>,
2280 ) {
2281 self.get_or_init_project(project, cx)
2282 .context
2283 .update(cx, |store, cx| {
2284 store.refresh(buffer.clone(), cursor_position, cx);
2285 });
2286 }
2287
2288 #[cfg(feature = "cli-support")]
2289 pub fn set_context_for_buffer(
2290 &mut self,
2291 project: &Entity<Project>,
2292 related_files: Vec<RelatedFile>,
2293 cx: &mut Context<Self>,
2294 ) {
2295 self.get_or_init_project(project, cx)
2296 .context
2297 .update(cx, |store, cx| {
2298 store.set_related_files(related_files, cx);
2299 });
2300 }
2301
2302 #[cfg(feature = "cli-support")]
2303 pub fn set_recent_paths_for_project(
2304 &mut self,
2305 project: &Entity<Project>,
2306 paths: impl IntoIterator<Item = project::ProjectPath>,
2307 cx: &mut Context<Self>,
2308 ) {
2309 let project_state = self.get_or_init_project(project, cx);
2310 project_state.recent_paths = paths.into_iter().collect();
2311 }
2312
2313 fn is_file_open_source(
2314 &self,
2315 project: &Entity<Project>,
2316 file: &Arc<dyn File>,
2317 cx: &App,
2318 ) -> bool {
2319 if !file.is_local() || file.is_private() {
2320 return false;
2321 }
2322 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2323 return false;
2324 };
2325 project_state
2326 .license_detection_watchers
2327 .get(&file.worktree_id(cx))
2328 .as_ref()
2329 .is_some_and(|watcher| watcher.is_project_open_source())
2330 }
2331
2332 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2333 self.data_collection_choice.is_enabled(cx)
2334 }
2335
2336 fn load_data_collection_choice() -> DataCollectionChoice {
2337 let choice = KEY_VALUE_STORE
2338 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2339 .log_err()
2340 .flatten();
2341
2342 match choice.as_deref() {
2343 Some("true") => DataCollectionChoice::Enabled,
2344 Some("false") => DataCollectionChoice::Disabled,
2345 Some(_) => {
2346 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2347 DataCollectionChoice::NotAnswered
2348 }
2349 None => DataCollectionChoice::NotAnswered,
2350 }
2351 }
2352
2353 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2354 self.data_collection_choice = self.data_collection_choice.toggle();
2355 let new_choice = self.data_collection_choice;
2356 let is_enabled = new_choice.is_enabled(cx);
2357 db::write_and_log(cx, move || {
2358 KEY_VALUE_STORE.write_kvp(
2359 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2360 is_enabled.to_string(),
2361 )
2362 });
2363 }
2364
2365 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2366 self.shown_predictions.iter()
2367 }
2368
2369 pub fn shown_completions_len(&self) -> usize {
2370 self.shown_predictions.len()
2371 }
2372
2373 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2374 self.rated_predictions.contains(id)
2375 }
2376
2377 pub fn rate_prediction(
2378 &mut self,
2379 prediction: &EditPrediction,
2380 rating: EditPredictionRating,
2381 feedback: String,
2382 cx: &mut Context<Self>,
2383 ) {
2384 let organization = self.user_store.read(cx).current_organization();
2385
2386 self.rated_predictions.insert(prediction.id.clone());
2387
2388 cx.background_spawn({
2389 let client = self.client.clone();
2390 let prediction_id = prediction.id.to_string();
2391 let inputs = serde_json::to_value(&prediction.inputs);
2392 let output = prediction
2393 .edit_preview
2394 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2395 async move {
2396 client
2397 .cloud_client()
2398 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2399 organization_id: organization.map(|organization| organization.id.clone()),
2400 request_id: prediction_id,
2401 rating: match rating {
2402 EditPredictionRating::Positive => "positive".to_string(),
2403 EditPredictionRating::Negative => "negative".to_string(),
2404 },
2405 inputs: inputs?,
2406 output,
2407 feedback,
2408 })
2409 .await?;
2410
2411 anyhow::Ok(())
2412 }
2413 })
2414 .detach_and_log_err(cx);
2415
2416 cx.notify();
2417 }
2418}
2419
/// Collapses the run of mergeable events at the back of `events` into one event.
///
/// Nothing happens when the newest event's `old_snapshot` belongs to a
/// different buffer than `latest_snapshot`, when fewer than two trailing events
/// are mergeable, or when `compute_diff_between_snapshots` yields `None`.
/// Otherwise the mergeable tail is replaced by a single `BufferChange` whose
/// diff spans from the oldest merged event's `old_snapshot` to `end_snapshot`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail out if the most recent event was recorded for another buffer.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Walk from newest to oldest, counting events until one cannot be merged
    // with its newer neighbour. The newest event is always counted.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event (or none) has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // `peek` does not consume, so the later `all` still visits the oldest event.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        // Path and repository metadata are taken from the oldest merged event.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every
                    // merged event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                // The edited range is anchored in `end_snapshot`.
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Drop the merged tail and append its replacement.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2486
2487pub(crate) fn filter_redundant_excerpts(
2488 mut related_files: Vec<RelatedFile>,
2489 cursor_path: &Path,
2490 cursor_row_range: Range<u32>,
2491) -> Vec<RelatedFile> {
2492 for file in &mut related_files {
2493 if file.path.as_ref() == cursor_path {
2494 file.excerpts.retain(|excerpt| {
2495 excerpt.row_range.start < cursor_row_range.start
2496 || excerpt.row_range.end > cursor_row_range.end
2497 });
2498 }
2499 }
2500 related_files.retain(|file| !file.excerpts.is_empty());
2501 related_files
2502}
2503
/// Error telling the user that a newer Zed version is needed to keep using
/// edit predictions; its message names the minimum version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest version accepted; interpolated into the error message.
    minimum_version: Version,
}
2511
/// The user's answer to the edit prediction data collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No usable answer has been stored.
    NotAnswered,
    /// Data collection was turned on.
    Enabled,
    /// Data collection was turned off.
    Disabled,
}
2518
2519impl DataCollectionChoice {
2520 pub fn is_enabled(self, cx: &App) -> bool {
2521 if cx.is_staff() {
2522 return true;
2523 }
2524 match self {
2525 Self::Enabled => true,
2526 Self::NotAnswered | Self::Disabled => false,
2527 }
2528 }
2529
2530 #[must_use]
2531 pub fn toggle(&self) -> DataCollectionChoice {
2532 match self {
2533 Self::Enabled => Self::Disabled,
2534 Self::Disabled => Self::Enabled,
2535 Self::NotAnswered => Self::Enabled,
2536 }
2537 }
2538}
2539
2540impl From<bool> for DataCollectionChoice {
2541 fn from(value: bool) -> Self {
2542 match value {
2543 true => DataCollectionChoice::Enabled,
2544 false => DataCollectionChoice::Disabled,
2545 }
2546 }
2547}
2548
/// Marker type carrying the dismissal state of the edit prediction upsell.
struct ZedPredictUpsell;
2550
2551impl Dismissable for ZedPredictUpsell {
2552 const KEY: &'static str = "dismissed-edit-predict-upsell";
2553
2554 fn dismissed() -> bool {
2555 // To make this backwards compatible with older versions of Zed, we
2556 // check if the user has seen the previous Edit Prediction Onboarding
2557 // before, by checking the data collection choice which was written to
2558 // the database once the user clicked on "Accept and Enable"
2559 if KEY_VALUE_STORE
2560 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2561 .log_err()
2562 .is_some_and(|s| s.is_some())
2563 {
2564 return true;
2565 }
2566
2567 KEY_VALUE_STORE
2568 .read_kvp(Self::KEY)
2569 .log_err()
2570 .is_some_and(|s| s.is_some())
2571 }
2572}
2573
2574pub fn should_show_upsell_modal() -> bool {
2575 !ZedPredictUpsell::dismissed()
2576}
2577
2578pub fn init(cx: &mut App) {
2579 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2580 workspace.register_action(
2581 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2582 ZedPredictModal::toggle(
2583 workspace,
2584 workspace.user_store().clone(),
2585 workspace.client().clone(),
2586 window,
2587 cx,
2588 )
2589 },
2590 );
2591
2592 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2593 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2594 settings
2595 .project
2596 .all_languages
2597 .edit_predictions
2598 .get_or_insert_default()
2599 .provider = Some(EditPredictionProvider::None)
2600 });
2601 });
2602 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2603 EditPredictionStore::try_global(cx).and_then(|store| {
2604 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2605 })
2606 }
2607
2608 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2609 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2610 copilot_ui::initiate_sign_in(copilot, window, cx);
2611 }
2612 });
2613 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2614 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2615 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2616 }
2617 });
2618 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2619 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2620 copilot_ui::initiate_sign_out(copilot, window, cx);
2621 }
2622 });
2623 })
2624 .detach();
2625}