1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use collections::{HashMap, HashSet};
12use db::kvp::{Dismissable, KEY_VALUE_STORE};
13use edit_prediction_context::EditPredictionExcerptOptions;
14use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::{
17 AsyncReadExt as _, FutureExt as _, StreamExt as _,
18 channel::mpsc::{self, UnboundedReceiver},
19 select_biased,
20};
21use gpui::BackgroundExecutor;
22use gpui::{
23 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
24 http_client::{self, AsyncBody, Method},
25 prelude::*,
26};
27use language::language_settings::all_language_settings;
28use language::{Anchor, Buffer, File, Point, ToPoint};
29use language::{BufferSnapshot, OffsetRangeExt};
30use language_model::{LlmApiToken, RefreshLlmTokenListener};
31use project::{Project, ProjectPath, WorktreeId};
32use release_channel::AppVersion;
33use semver::Version;
34use serde::de::DeserializeOwned;
35use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
36use std::collections::{VecDeque, hash_map};
37use workspace::Workspace;
38
39use std::ops::Range;
40use std::path::Path;
41use std::rc::Rc;
42use std::str::FromStr as _;
43use std::sync::{Arc, LazyLock};
44use std::time::{Duration, Instant};
45use std::{env, mem};
46use thiserror::Error;
47use util::{RangeExt as _, ResultExt as _};
48use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
49
50mod cursor_excerpt;
51mod license_detection;
52pub mod mercury;
53mod onboarding_modal;
54pub mod open_ai_response;
55mod prediction;
56pub mod sweep_ai;
57
58#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
59pub mod udiff;
60
61mod zed_edit_prediction_delegate;
62pub mod zeta1;
63pub mod zeta2;
64
65#[cfg(test)]
66mod edit_prediction_tests;
67
68use crate::license_detection::LicenseDetectionWatcher;
69use crate::mercury::Mercury;
70use crate::onboarding_modal::ZedPredictModal;
71pub use crate::prediction::EditPrediction;
72pub use crate::prediction::EditPredictionId;
73use crate::prediction::EditPredictionResult;
74pub use crate::sweep_ai::SweepAi;
75pub use telemetry_events::EditPredictionRating;
76pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
77
// User-invocable actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits landing within this many rows of the previous edit are coalesced
/// into the same history event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key persisting the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long queued rejections may accumulate before a reject request is sent
/// (see `handle_rejected_predictions`).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
93
94pub struct SweepFeatureFlag;
95
96impl FeatureFlag for SweepFeatureFlag {
97 const NAME: &str = "sweep-ai";
98}
99
100pub struct MercuryFeatureFlag;
101
102impl FeatureFlag for MercuryFeatureFlag {
103 const NAME: &str = "mercury";
104}
105
/// Default context/prompt options, used until a caller overrides them via
/// `EditPredictionStore::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};

/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value; routes requests
/// to a local Ollama server instead of the default backend.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));

/// Model id used for prediction requests; overridable via `ZED_ZETA2_MODEL`
/// ("zeta2-exp" maps to a specific fine-tuned deployment).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});

/// Optional override for the predict-edits endpoint: `ZED_PREDICT_EDITS_URL`
/// wins, else the local Ollama URL when `USE_OLLAMA` is set, else `None`.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
136
/// Feature flag gating the Zeta2 edit prediction model.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff accounts get Zeta2 without the flag being explicitly enabled.
    fn enabled_for_staff() -> bool {
        true
    }
}

/// GPUI global wrapper so the single `EditPredictionStore` entity can be
/// looked up from any `App` context.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
151
/// App-wide store coordinating edit predictions: tracks per-project edit
/// history and context, dispatches requests to the selected model, and
/// reports accept/reject outcomes.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription alive for the store's lifetime.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server demands a newer Zed
    // version (see MINIMUM_REQUIRED_VERSION_HEADER_NAME import); the code
    // that sets it is not in this chunk — confirm.
    update_required: bool,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Sender feeding the background batcher in `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // Most recently shown predictions, newest at the front; capped at 50
    // in `did_show_current_prediction`.
    shown_predictions: VecDeque<EditPrediction>,
    // Ids of shown predictions that the user has rated.
    rated_predictions: HashSet<EditPredictionId>,
}

/// Which backing model serves prediction requests.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}

/// Everything a model implementation needs to produce a prediction for a
/// single cursor position.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Arc<[RelatedFile]>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}

/// Tunables for context-excerpt selection and prompt formatting.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
199
/// Events streamed to debug listeners attached via `debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Human-readable (label, value) pairs, e.g. cache-hit and LSP-latency stats.
    pub metadata: Vec<(&'static str, SharedString)>,
}

#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// Per-project bookkeeping: edit history, context retrieval, and the
/// lifecycle of the current and pending predictions.
struct ProjectState {
    // Finalized edit-history events, oldest first; capped via EVENT_COUNT_MAX.
    events: VecDeque<Arc<zeta_prompt::Event>>,
    // The most recent, still-coalescing change; folded into `events` later.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two prediction requests in flight at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled while still in flight.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One license watcher per worktree, used to tag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
253
impl ProjectState {
    /// All history events, including the still-coalescing `last_event`
    /// (finalized on the fly; omitted if it produced no observable change).
    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks `pending_prediction` as cancelled and, once its request resolves
    /// to a prediction id, reports that prediction as rejected (`Canceled`).
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // The request may finish with no prediction; nothing to reject then.
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }
}
287
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
}

impl CurrentEditPrediction {
    /// Whether this (new) prediction should replace `old_prediction`.
    ///
    /// A new prediction that no longer applies to its buffer is never taken;
    /// a stale old prediction is always replaced. When both are single edits
    /// requested by the same buffer, the replacement only happens if the new
    /// edit covers the same range and merely extends the old text — this
    /// avoids visual churn while the user types.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // Interpolation fails when the buffer has diverged from the
        // prediction; a non-applicable new prediction is discarded.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // The old prediction no longer applies — replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
333
334#[derive(Debug, Clone)]
335enum PredictionRequestedBy {
336 DiagnosticsUpdate,
337 Buffer(EntityId),
338}
339
340impl PredictionRequestedBy {
341 pub fn buffer_id(&self) -> Option<EntityId> {
342 match self {
343 PredictionRequestedBy::DiagnosticsUpdate => None,
344 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
345 }
346 }
347}
348
/// A prediction request that has been spawned but not yet resolved.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    // Resolves to the resulting prediction's id, or None if the request
    // produced nothing.
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}

/// A buffer being tracked for history, plus its last-seen snapshot.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent, still-coalescing buffer change; nearby follow-up edits
/// update `new_snapshot`/`end_edit_anchor` in place
/// (see `report_changes_for_buffer`).
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    end_edit_anchor: Option<Anchor>,
}
384
impl LastEvent {
    /// Converts the accumulated change into a `BufferChange` event, or `None`
    /// when nothing observable changed (same path and an empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<zeta_prompt::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Open source only if both the old and new locations have files whose
        // worktree watcher detected an open-source project.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(zeta_prompt::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
421
422fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
423 if let Some(file) = snapshot.file() {
424 file.full_path(cx).into()
425 } else {
426 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
427 }
428}
429
430impl EditPredictionStore {
431 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
432 cx.try_global::<EditPredictionStoreGlobal>()
433 .map(|global| global.0.clone())
434 }
435
436 pub fn global(
437 client: &Arc<Client>,
438 user_store: &Entity<UserStore>,
439 cx: &mut App,
440 ) -> Entity<Self> {
441 cx.try_global::<EditPredictionStoreGlobal>()
442 .map(|global| global.0.clone())
443 .unwrap_or_else(|| {
444 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
445 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
446 ep_store
447 })
448 }
449
    /// Creates the store: spawns the background task that batches rejected
    /// predictions, and wires up LLM-token refresh, feature-flag, and
    /// settings observers.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        // Restore the user's persisted data-collection consent.
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are queued on this channel and flushed in batches by
        // `handle_rejected_predictions` on the background executor.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        // Context retrieval depends on feature flags and settings, so
        // configure it now and re-configure whenever either changes.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
521
    /// Selects which backing model subsequent prediction requests will use.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
525
526 pub fn has_sweep_api_token(&self) -> bool {
527 self.sweep_ai
528 .api_token
529 .clone()
530 .now_or_never()
531 .flatten()
532 .is_some()
533 }
534
535 pub fn has_mercury_api_token(&self) -> bool {
536 self.mercury
537 .api_token
538 .clone()
539 .now_or_never()
540 .flatten()
541 .is_some()
542 }
543
    /// Installs a cache used to record/replay model traffic during evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
548
    /// The current prediction/context options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction/context options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Enables or disables use of retrieved context in prediction requests.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
560
561 pub fn clear_history(&mut self) {
562 for project_state in self.projects.values_mut() {
563 project_state.events.clear();
564 }
565 }
566
567 pub fn edit_history_for_project(
568 &self,
569 project: &Entity<Project>,
570 ) -> Vec<Arc<zeta_prompt::Event>> {
571 self.projects
572 .get(&project.entity_id())
573 .map(|project_state| project_state.events.iter().cloned().collect())
574 .unwrap_or_default()
575 }
576
    /// Related files retrieved as context for `project`; empty when the
    /// project isn't registered.
    pub fn context_for_project<'a>(
        &'a self,
        project: &Entity<Project>,
        cx: &'a App,
    ) -> Arc<[RelatedFile]> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.context.read(cx).related_files())
            .unwrap_or_else(|| vec![].into())
    }

    /// Like `context_for_project`, but pairs each related file with its
    /// buffer; `None` when the project isn't registered.
    pub fn context_for_project_with_buffers<'a>(
        &'a self,
        project: &Entity<Project>,
        cx: &'a App,
    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.context.read(cx).related_files_with_buffers())
    }
597
598 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
599 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
600 self.user_store.read(cx).edit_prediction_usage()
601 } else {
602 None
603 }
604 }
605
    /// Starts tracking `project`, initializing its state if needed.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }

    /// Subscribes to `buffer` so its edits feed `project`'s edit history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
619
    /// Returns the state for `project`, creating it on first sight: sets up
    /// the related-excerpt store and the project event subscription.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Forward excerpt-store refresh events to any debug listener.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }

    /// Drops all state for `project` (history, predictions, subscriptions).
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
655
656 fn handle_excerpt_store_event(
657 &mut self,
658 project_entity_id: EntityId,
659 event: &RelatedExcerptStoreEvent,
660 ) {
661 if let Some(project_state) = self.projects.get(&project_entity_id) {
662 if let Some(debug_tx) = project_state.debug_tx.clone() {
663 match event {
664 RelatedExcerptStoreEvent::StartedRefresh => {
665 debug_tx
666 .unbounded_send(DebugEvent::ContextRetrievalStarted(
667 ContextRetrievalStartedDebugEvent {
668 project_entity_id: project_entity_id,
669 timestamp: Instant::now(),
670 search_prompt: String::new(),
671 },
672 ))
673 .ok();
674 }
675 RelatedExcerptStoreEvent::FinishedRefresh {
676 cache_hit_count,
677 cache_miss_count,
678 mean_definition_latency,
679 max_definition_latency,
680 } => {
681 debug_tx
682 .unbounded_send(DebugEvent::ContextRetrievalFinished(
683 ContextRetrievalFinishedDebugEvent {
684 project_entity_id: project_entity_id,
685 timestamp: Instant::now(),
686 metadata: vec![
687 (
688 "Cache Hits",
689 format!(
690 "{}/{}",
691 cache_hit_count,
692 cache_hit_count + cache_miss_count
693 )
694 .into(),
695 ),
696 (
697 "Max LSP Time",
698 format!("{} ms", max_definition_latency.as_millis())
699 .into(),
700 ),
701 (
702 "Mean LSP Time",
703 format!("{} ms", mean_definition_latency.as_millis())
704 .into(),
705 ),
706 ],
707 },
708 ))
709 .ok();
710 }
711 }
712 }
713 }
714 }
715
    /// Attaches a debug listener to `project` and returns the receiving end
    /// of the event stream. Replaces any previously attached listener.
    pub fn debug_info(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> mpsc::UnboundedReceiver<DebugEvent> {
        let project_state = self.get_or_init_project(project, cx);
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        project_state.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }
726
    /// Reacts to project events: keeps the recently-active path list fresh,
    /// and refreshes predictions on diagnostics changes (Zeta2 flag only).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any older entry
                    // for it first so the list stays deduplicated.
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
759
    /// Ensures `buffer` is tracked in `project_state`: installs a per-worktree
    /// license watcher and subscribes to the buffer's edits and release.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license-detection watcher per worktree, and drop
        // it again when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit into the project's history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Stop tracking the buffer once it's released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
821
    /// Records a buffer edit into `project`'s history, coalescing it into the
    /// pending `last_event` when it continues the same run of edits.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing actually changed since the last recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End position of the last edit in this change; used for proximity
        // grouping below.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // This change directly continues the pending event when it is the
            // next version of the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and its final edit lands within CHANGE_GROUPING_LINE_SPAN
            // rows of the pending event's final edit.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event in place instead of starting a new one.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Evict the oldest event to make room before finalizing the pending one.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
884
    /// Resolves the project's current prediction relative to `buffer`:
    /// `Local` when the prediction edits this buffer, `Jump` when it targets
    /// another buffer but should be surfaced here, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only the buffer that requested the prediction shows the jump
            // affordance; diagnostics-triggered predictions show it anywhere.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
916
    /// Accepts the project's current prediction: cancels remaining pending
    /// requests and reports the acceptance to the server (Zeta models only).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Accept reporting only applies to Zed's own models.
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the endpoint (development aid).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
969
    /// Background loop draining the rejection channel: batches rejections
    /// (debounced and size-capped) and flushes them to the reject endpoint.
    /// On a failed flush, items stay queued and ride along with the next batch.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Keep accumulating until the batch reaches half the per-request
            // cap, the debounce timer fires, or the channel is exhausted.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush only the newest `flush_count` items, capped per request.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Drop the flushed slice only on success.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1027
1028 fn reject_current_prediction(
1029 &mut self,
1030 reason: EditPredictionRejectReason,
1031 project: &Entity<Project>,
1032 ) {
1033 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1034 project_state.pending_predictions.clear();
1035 if let Some(prediction) = project_state.current_prediction.take() {
1036 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1037 }
1038 };
1039 }
1040
1041 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1042 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1043 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1044 if !current_prediction.was_shown {
1045 current_prediction.was_shown = true;
1046 self.shown_predictions
1047 .push_front(current_prediction.prediction.clone());
1048 if self.shown_predictions.len() > 50 {
1049 let completion = self.shown_predictions.pop_back().unwrap();
1050 self.rated_predictions.remove(&completion.id);
1051 }
1052 }
1053 }
1054 }
1055 }
1056
    /// Queues a rejection report for `prediction_id`; the background batcher
    /// (`handle_rejected_predictions`) sends it later. Zeta models only.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
    ) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        self.reject_predictions_tx
            .unbounded_send(EditPredictionRejection {
                request_id: prediction_id.to_string(),
                reason,
                was_shown,
            })
            .log_err();
    }
1076
1077 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1078 self.projects
1079 .get(&project.entity_id())
1080 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1081 }
1082
    /// Requests a fresh prediction for a cursor position in `buffer`, going
    /// through the throttled refresh queue.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                // The store was dropped; nothing to request.
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Tag the result with the requesting buffer so the UI knows
                // where to surface it.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1118
    /// Requests a prediction at the next diagnostic location in the active
    /// buffer. Skipped while a prediction already exists; if one appears
    /// while this request is in flight, the new result is marked rejected
    /// with `CurrentPreferred`.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Open the project's active entry; bail if there isn't one.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find a diagnostic to jump to; no diagnostics, no request.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-requested prediction that arrived in the
                        // meantime wins; mark this one rejected instead.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1199
    // Minimum spacing between consecutive refreshes triggered by the same
    // entity. Zero under `cfg(test)` so tests don't wait on real timers.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1204
1205 fn queue_prediction_refresh(
1206 &mut self,
1207 project: Entity<Project>,
1208 throttle_entity: EntityId,
1209 cx: &mut Context<Self>,
1210 do_refresh: impl FnOnce(
1211 WeakEntity<Self>,
1212 &mut AsyncApp,
1213 )
1214 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1215 + 'static,
1216 ) {
1217 let project_state = self.get_or_init_project(&project, cx);
1218 let pending_prediction_id = project_state.next_pending_prediction_id;
1219 project_state.next_pending_prediction_id += 1;
1220 let last_request = project_state.last_prediction_refresh;
1221
1222 let task = cx.spawn(async move |this, cx| {
1223 if let Some((last_entity, last_timestamp)) = last_request
1224 && throttle_entity == last_entity
1225 && let Some(timeout) =
1226 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1227 {
1228 cx.background_executor().timer(timeout).await;
1229 }
1230
1231 // If this task was cancelled before the throttle timeout expired,
1232 // do not perform a request.
1233 let mut is_cancelled = true;
1234 this.update(cx, |this, cx| {
1235 let project_state = this.get_or_init_project(&project, cx);
1236 if !project_state
1237 .cancelled_predictions
1238 .remove(&pending_prediction_id)
1239 {
1240 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1241 is_cancelled = false;
1242 }
1243 })
1244 .ok();
1245 if is_cancelled {
1246 return None;
1247 }
1248
1249 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1250 let new_prediction_id = new_prediction_result
1251 .as_ref()
1252 .map(|(prediction, _)| prediction.id.clone());
1253
1254 // When a prediction completes, remove it from the pending list, and cancel
1255 // any pending predictions that were enqueued before it.
1256 this.update(cx, |this, cx| {
1257 let project_state = this.get_or_init_project(&project, cx);
1258
1259 let is_cancelled = project_state
1260 .cancelled_predictions
1261 .remove(&pending_prediction_id);
1262
1263 let new_current_prediction = if !is_cancelled
1264 && let Some((prediction_result, requested_by)) = new_prediction_result
1265 {
1266 match prediction_result.prediction {
1267 Ok(prediction) => {
1268 let new_prediction = CurrentEditPrediction {
1269 requested_by,
1270 prediction,
1271 was_shown: false,
1272 };
1273
1274 if let Some(current_prediction) =
1275 project_state.current_prediction.as_ref()
1276 {
1277 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1278 {
1279 this.reject_current_prediction(
1280 EditPredictionRejectReason::Replaced,
1281 &project,
1282 );
1283
1284 Some(new_prediction)
1285 } else {
1286 this.reject_prediction(
1287 new_prediction.prediction.id,
1288 EditPredictionRejectReason::CurrentPreferred,
1289 false,
1290 );
1291 None
1292 }
1293 } else {
1294 Some(new_prediction)
1295 }
1296 }
1297 Err(reject_reason) => {
1298 this.reject_prediction(prediction_result.id, reject_reason, false);
1299 None
1300 }
1301 }
1302 } else {
1303 None
1304 };
1305
1306 let project_state = this.get_or_init_project(&project, cx);
1307
1308 if let Some(new_prediction) = new_current_prediction {
1309 project_state.current_prediction = Some(new_prediction);
1310 }
1311
1312 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1313 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1314 if pending_prediction.id == pending_prediction_id {
1315 pending_predictions.remove(ix);
1316 for pending_prediction in pending_predictions.drain(0..ix) {
1317 project_state.cancel_pending_prediction(pending_prediction, cx)
1318 }
1319 break;
1320 }
1321 }
1322 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1323 cx.notify();
1324 })
1325 .ok();
1326
1327 new_prediction_id
1328 });
1329
1330 if project_state.pending_predictions.len() <= 1 {
1331 project_state.pending_predictions.push(PendingPrediction {
1332 id: pending_prediction_id,
1333 task,
1334 });
1335 } else if project_state.pending_predictions.len() == 2 {
1336 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1337 project_state.pending_predictions.push(PendingPrediction {
1338 id: pending_prediction_id,
1339 task,
1340 });
1341 project_state.cancel_pending_prediction(pending_prediction, cx);
1342 }
1343 }
1344
1345 pub fn request_prediction(
1346 &mut self,
1347 project: &Entity<Project>,
1348 active_buffer: &Entity<Buffer>,
1349 position: language::Anchor,
1350 trigger: PredictEditsRequestTrigger,
1351 cx: &mut Context<Self>,
1352 ) -> Task<Result<Option<EditPredictionResult>>> {
1353 self.request_prediction_internal(
1354 project.clone(),
1355 active_buffer.clone(),
1356 position,
1357 trigger,
1358 cx.has_flag::<Zeta2FeatureFlag>(),
1359 cx,
1360 )
1361 }
1362
    /// Core prediction request: gathers events and (optionally) related-file
    /// context, dispatches to the configured model backend, and — when no
    /// prediction comes back and `allow_jump` is set — retries once at the
    /// closest diagnostic outside the already-searched range.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor count as already
        // covered by this request; the jump fallback skips them.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let events = project_state.events(cx);
        let has_events = !events.is_empty();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        // Context gathering is gated on the `use_context` setting.
        let related_files = if self.use_context {
            self.context_for_project(&project, cx)
        } else {
            Vec::new().into()
        };

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
        };

        // Dispatch to whichever model backend is currently configured.
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                // No prediction here: retry once at the nearest diagnostic.
                // Passing `allow_jump: false` bounds the recursion to depth 1.
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1449
    /// Finds the best diagnostic location to jump to: first the closest
    /// diagnostic in the active buffer outside the already-searched range;
    /// failing that, a diagnostic in the project buffer whose path shares the
    /// most leading directory components with the active buffer's path.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the closest diagnostic to the cursor that wasn't close enough
        // to be included in the last request.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by row distance from the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    // Skip the active buffer itself.
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Picks the diagnostic with the minimum severity value —
                // presumably the most severe, as in LSP where Error = 1.
                // NOTE(review): confirm severity ordering.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1528
    /// Sends a raw OpenAI-style request to the predict-edits endpoint (or to
    /// the `PREDICT_EDITS_URL` override), optionally consulting and
    /// populating the eval cache when built with `eval-support`.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Cache lookup keyed by a hash of the URL plus the pretty-printed
        // request body; a hit skips the network and reports no usage.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Populate the cache on a miss so later eval runs are hermetic.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1586
    /// Applies the side effects of an API response: on success, forwards any
    /// usage info to the user store; on a `ZedUpdateRequiredError`, flags the
    /// store as requiring an update and shows an update notification. Returns
    /// the payload (or the original error) unchanged.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    // Entity may be gone; ignore failures updating it.
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1630
    /// Sends an authenticated POST to the LLM service, enforcing the server's
    /// minimum-required-version header and retrying exactly once with a
    /// refreshed token when the current LLM token has expired.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: once with the cached token, once after a
        // single token refresh.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version regardless of
            // status; surface that as a ZedUpdateRequiredError.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1694
1695 pub fn refresh_context(
1696 &mut self,
1697 project: &Entity<Project>,
1698 buffer: &Entity<language::Buffer>,
1699 cursor_position: language::Anchor,
1700 cx: &mut Context<Self>,
1701 ) {
1702 if self.use_context {
1703 self.get_or_init_project(project, cx)
1704 .context
1705 .update(cx, |store, cx| {
1706 store.refresh(buffer.clone(), cursor_position, cx);
1707 });
1708 }
1709 }
1710
1711 #[cfg(feature = "eval-support")]
1712 pub fn set_context_for_buffer(
1713 &mut self,
1714 project: &Entity<Project>,
1715 related_files: Vec<RelatedFile>,
1716 cx: &mut Context<Self>,
1717 ) {
1718 self.get_or_init_project(project, cx)
1719 .context
1720 .update(cx, |store, _| {
1721 store.set_related_files(related_files);
1722 });
1723 }
1724
1725 fn is_file_open_source(
1726 &self,
1727 project: &Entity<Project>,
1728 file: &Arc<dyn File>,
1729 cx: &App,
1730 ) -> bool {
1731 if !file.is_local() || file.is_private() {
1732 return false;
1733 }
1734 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1735 return false;
1736 };
1737 project_state
1738 .license_detection_watchers
1739 .get(&file.worktree_id(cx))
1740 .as_ref()
1741 .is_some_and(|watcher| watcher.is_project_open_source())
1742 }
1743
1744 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1745 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1746 }
1747
1748 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
1749 if !self.data_collection_choice.is_enabled() {
1750 return false;
1751 }
1752 events.iter().all(|event| {
1753 matches!(
1754 event.as_ref(),
1755 zeta_prompt::Event::BufferChange {
1756 in_open_source_repo: true,
1757 ..
1758 }
1759 )
1760 })
1761 }
1762
1763 fn load_data_collection_choice() -> DataCollectionChoice {
1764 let choice = KEY_VALUE_STORE
1765 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1766 .log_err()
1767 .flatten();
1768
1769 match choice.as_deref() {
1770 Some("true") => DataCollectionChoice::Enabled,
1771 Some("false") => DataCollectionChoice::Disabled,
1772 Some(_) => {
1773 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1774 DataCollectionChoice::NotAnswered
1775 }
1776 None => DataCollectionChoice::NotAnswered,
1777 }
1778 }
1779
1780 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1781 self.data_collection_choice = self.data_collection_choice.toggle();
1782 let new_choice = self.data_collection_choice;
1783 db::write_and_log(cx, move || {
1784 KEY_VALUE_STORE.write_kvp(
1785 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1786 new_choice.is_enabled().to_string(),
1787 )
1788 });
1789 }
1790
    /// Iterator over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether the user has already rated the prediction with the given id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
1802
    /// Records a user rating for `prediction`, reports it via telemetry, and
    /// flushes telemetry immediately so the feedback is not lost.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush eagerly rather than waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1821
1822 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1823 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1824 && all_language_settings(None, cx).edit_predictions.use_context;
1825 }
1826}
1827
/// Error raised when the LLM service reports (via the minimum-required-version
/// response header) that the running Zed is too old for edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // The minimum Zed version the server will accept.
    minimum_version: Version,
}
1835
#[cfg(feature = "eval-support")]
/// Cache key for eval runs: the kind of entry plus a hash of the request.
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1838
#[cfg(feature = "eval-support")]
/// Distinguishes the kind of request an eval-cache entry stores.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1846
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Renders the lowercase identifier for this entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
1857
#[cfg(feature = "eval-support")]
/// Read/write store used to cache LLM responses between eval runs, keyed by
/// entry kind plus a hash of the request.
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores the response `value` for `key`; `input` is the request text.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1863
/// The user's choice about contributing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No explicit choice has been made yet.
    NotAnswered,
    /// The user explicitly opted in.
    Enabled,
    /// The user explicitly opted out.
    Disabled,
}
1870
1871impl DataCollectionChoice {
1872 pub fn is_enabled(self) -> bool {
1873 match self {
1874 Self::Enabled => true,
1875 Self::NotAnswered | Self::Disabled => false,
1876 }
1877 }
1878
1879 pub fn is_answered(self) -> bool {
1880 match self {
1881 Self::Enabled | Self::Disabled => true,
1882 Self::NotAnswered => false,
1883 }
1884 }
1885
1886 #[must_use]
1887 pub fn toggle(&self) -> DataCollectionChoice {
1888 match self {
1889 Self::Enabled => Self::Disabled,
1890 Self::Disabled => Self::Enabled,
1891 Self::NotAnswered => Self::Enabled,
1892 }
1893 }
1894}
1895
1896impl From<bool> for DataCollectionChoice {
1897 fn from(value: bool) -> Self {
1898 match value {
1899 true => DataCollectionChoice::Enabled,
1900 false => DataCollectionChoice::Disabled,
1901 }
1902 }
1903}
1904
1905struct ZedPredictUpsell;
1906
1907impl Dismissable for ZedPredictUpsell {
1908 const KEY: &'static str = "dismissed-edit-predict-upsell";
1909
1910 fn dismissed() -> bool {
1911 // To make this backwards compatible with older versions of Zed, we
1912 // check if the user has seen the previous Edit Prediction Onboarding
1913 // before, by checking the data collection choice which was written to
1914 // the database once the user clicked on "Accept and Enable"
1915 if KEY_VALUE_STORE
1916 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1917 .log_err()
1918 .is_some_and(|s| s.is_some())
1919 {
1920 return true;
1921 }
1922
1923 KEY_VALUE_STORE
1924 .read_kvp(Self::KEY)
1925 .log_err()
1926 .is_some_and(|s| s.is_some())
1927 }
1928}
1929
/// Whether the edit-prediction upsell modal should still be shown.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
1933
/// Registers edit-prediction onboarding actions on every new workspace.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Toggles the onboarding modal on demand.
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): "ResetOnboarding" resets the edit-prediction provider
        // setting to None — presumably so onboarding re-triggers on next use;
        // confirm this is the intended reset mechanism.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}