1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
12use collections::{HashMap, HashSet};
13use db::kvp::{Dismissable, KEY_VALUE_STORE};
14use edit_prediction_context::EditPredictionExcerptOptions;
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::{
20 mpsc::{self, UnboundedReceiver},
21 oneshot,
22 },
23 select_biased,
24};
25use gpui::BackgroundExecutor;
26use gpui::{
27 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
28 http_client::{self, AsyncBody, Method},
29 prelude::*,
30};
31use language::language_settings::all_language_settings;
32use language::{Anchor, Buffer, File, Point, ToPoint};
33use language::{BufferSnapshot, OffsetRangeExt};
34use language_model::{LlmApiToken, RefreshLlmTokenListener};
35use project::{Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
40use std::collections::{VecDeque, hash_map};
41use workspace::Workspace;
42
43use std::ops::Range;
44use std::path::Path;
45use std::rc::Rc;
46use std::str::FromStr as _;
47use std::sync::{Arc, LazyLock};
48use std::time::{Duration, Instant};
49use std::{env, mem};
50use thiserror::Error;
51use util::{RangeExt as _, ResultExt as _};
52use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
53
54mod cursor_excerpt;
55mod license_detection;
56pub mod mercury;
57mod onboarding_modal;
58pub mod open_ai_response;
59mod prediction;
60pub mod sweep_ai;
61pub mod udiff;
62mod xml_edits;
63mod zed_edit_prediction_delegate;
64pub mod zeta1;
65pub mod zeta2;
66
67#[cfg(test)]
68mod edit_prediction_tests;
69
70use crate::license_detection::LicenseDetectionWatcher;
71use crate::mercury::Mercury;
72use crate::onboarding_modal::ZedPredictModal;
73pub use crate::prediction::EditPrediction;
74pub use crate::prediction::EditPredictionId;
75pub use crate::prediction::EditPredictionInputs;
76use crate::prediction::EditPredictionResult;
77pub use crate::sweep_ai::SweepAi;
78pub use telemetry_events::EditPredictionRating;
79pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
80
// Registers the `edit_prediction` action namespace with gpui.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
90
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into the
/// same event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key under which the user's data-collection choice is persisted in the KVP store.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long rejected predictions are batched before a flush request is sent.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
96
97pub struct SweepFeatureFlag;
98
99impl FeatureFlag for SweepFeatureFlag {
100 const NAME: &str = "sweep-ai";
101}
102
103pub struct MercuryFeatureFlag;
104
105impl FeatureFlag for MercuryFeatureFlag {
106 const NAME: &str = "mercury";
107}
108
/// Default Zeta options, used until overridden via `set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        // Aim to spend half of the excerpt byte budget before the cursor.
        target_before_cursor_over_total_bytes: 0.5,
    },
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    prompt_format: PromptFormat::DEFAULT,
};
118
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value; routes requests
/// to a local Ollama server (see `EDIT_PREDICTIONS_MODEL_ID` / `PREDICT_EDITS_URL`).
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
121
/// Model id used for edit-prediction requests, overridable via
/// `ZED_ZETA2_MODEL`; falls back to an Ollama model when `USE_OLLAMA` is set.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional override for the predict-edits endpoint: `ZED_PREDICT_EDITS_URL`
/// wins, then the local Ollama endpoint when `USE_OLLAMA` is set; `None`
/// means use the default hosted endpoint.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
140
/// Feature flag gating the Zeta2 backend; enabled by default for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
150
/// App-global handle to the singleton [`EditPredictionStore`].
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
155
/// App-wide store coordinating edit predictions across all registered projects.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed through `_llm_token_subscription` when the listener fires.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    // When set, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    // Rejections are queued here and flushed in batches by
    // `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // Recently shown predictions, newest first (bounded; see
    // `did_show_current_prediction`).
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
176
/// Which backend produces edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
}
185
/// Tunable parameters for Zeta prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Controls sizing of the excerpt taken around the cursor.
    pub context: EditPredictionExcerptOptions,
    /// Upper bound on the rendered prompt size, in bytes.
    pub max_prompt_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
192
/// Events streamed to debug listeners registered via
/// [`EditPredictionStore::debug_info`].
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionRequested(EditPredictionRequestedDebugEvent),
}
199
/// Emitted when a related-excerpt context refresh begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}
206
/// Emitted when a related-excerpt context refresh completes.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Human-readable (label, value) pairs, e.g. cache-hit and LSP timing stats.
    pub metadata: Vec<(&'static str, SharedString)>,
}
213
/// Emitted when an edit-prediction request is issued.
#[derive(Debug)]
pub struct EditPredictionRequestedDebugEvent {
    pub inputs: EditPredictionInputs,
    /// How long context retrieval took before the request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Locally rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
223
/// Alias for the cloud protocol's request debug-info payload.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
225
/// Per-project prediction state.
struct ProjectState {
    // Finalized change events, oldest first (bounded via `EVENT_COUNT_MAX`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // Most recent change, kept un-finalized so nearby edits can coalesce.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // In-flight prediction requests, at most 2 at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Signals listeners when related-excerpt context finishes refreshing.
    context_updates_tx: smol::channel::Sender<()>,
    context_updates_rx: smol::channel::Receiver<()>,
    // Entity and timestamp of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions cancelled before their task resolved.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // Per-worktree license watchers used to decide open-source status.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
242
impl ProjectState {
    /// All finalized events plus the still-pending `last_event` (when it
    /// finalizes to a non-empty change), in chronological order.
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks `pending_prediction` as cancelled and, once its task resolves,
    /// reports the produced prediction (if any) as rejected with `Canceled`.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<EditPredictionStore>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // The task yields the id of the prediction it produced, if any.
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }
}
276
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether this prediction has been displayed to the user at least once.
    pub was_shown: bool,
}
283
impl CurrentEditPrediction {
    /// Whether `self` (the newer prediction) should replace `old_prediction`
    /// as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer interpolates against its buffer's
        // current state, it is stale and must not replace anything.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction is stale, always replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Single-edit case: only replace when the new edit covers the same
            // range and extends the old text (e.g. more content streamed in).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
322
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update (see `refresh_prediction_from_diagnostics`).
    DiagnosticsUpdate,
    /// Triggered by activity in the buffer with this entity id.
    Buffer(EntityId),
}
328
329impl PredictionRequestedBy {
330 pub fn buffer_id(&self) -> Option<EntityId> {
331 match self {
332 PredictionRequestedBy::DiagnosticsUpdate => None,
333 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
334 }
335 }
336}
337
/// A prediction request that has been queued but not yet resolved.
#[derive(Debug)]
struct PendingPrediction {
    // Matched against `ProjectState::cancelled_predictions`.
    id: usize,
    // Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
}
343
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer; presented as a jump.
    Jump { prediction: &'a EditPrediction },
}
350
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Both variants wrap the same reference type, so a single or-pattern
    /// extracts it regardless of variant.
    fn deref(&self) -> &Self::Target {
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
362
/// Bookkeeping for a buffer whose edits feed the project's event history.
struct RegisteredBuffer {
    // Snapshot as of the last reported change (see `report_changes_for_buffer`).
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
367
/// The most recent buffer change, kept un-finalized so consecutive nearby
/// edits can be coalesced into a single event.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // Anchor at the end of the latest edit; used for proximity-based grouping.
    end_edit_anchor: Option<Anchor>,
}
373
impl LastEvent {
    /// Converts the buffered change into a protocol `Event`, or `None` when
    /// the accumulated change is a no-op (same path and empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Only mark the change as open source when both the old and new files
        // belong to worktrees whose watcher reports the project as open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
410
411fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
412 if let Some(file) = snapshot.file() {
413 file.full_path(cx).into()
414 } else {
415 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
416 }
417}
418
419impl EditPredictionStore {
420 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
421 cx.try_global::<EditPredictionStoreGlobal>()
422 .map(|global| global.0.clone())
423 }
424
425 pub fn global(
426 client: &Arc<Client>,
427 user_store: &Entity<UserStore>,
428 cx: &mut App,
429 ) -> Entity<Self> {
430 cx.try_global::<EditPredictionStoreGlobal>()
431 .map(|global| global.0.clone())
432 .unwrap_or_else(|| {
433 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
434 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
435 ep_store
436 })
437 }
438
    /// Creates the store: spawns the background task that flushes rejected
    /// predictions, subscribes to LLM-token refreshes and settings changes,
    /// and configures context retrieval.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections flow through this channel to a background batching task.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        // Re-run context-retrieval configuration when feature flags arrive or
        // settings change.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
511
    /// Selects which backend model produces predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
515
516 pub fn has_sweep_api_token(&self) -> bool {
517 self.sweep_ai
518 .api_token
519 .clone()
520 .now_or_never()
521 .flatten()
522 .is_some()
523 }
524
525 pub fn has_mercury_api_token(&self) -> bool {
526 self.mercury
527 .api_token
528 .clone()
529 .now_or_never()
530 .flatten()
531 .is_some()
532 }
533
    /// Installs a cache used by the eval harness (eval-support builds only).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
538
539 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
540 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
541 self.debug_tx = Some(debug_watch_tx);
542 debug_watch_rx
543 }
544
    /// Current Zeta options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
548
    /// Replaces the Zeta options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
552
    /// Toggles whether retrieved context is included in prediction requests.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
556
557 pub fn clear_history(&mut self) {
558 for project_state in self.projects.values_mut() {
559 project_state.events.clear();
560 }
561 }
562
563 pub fn context_for_project<'a>(
564 &'a self,
565 project: &Entity<Project>,
566 cx: &'a App,
567 ) -> &'a [RelatedFile] {
568 self.projects
569 .get(&project.entity_id())
570 .map(|project| project.context.read(cx).related_files())
571 .unwrap_or(&[])
572 }
573
574 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
575 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
576 self.user_store.read(cx).edit_prediction_usage()
577 } else {
578 None
579 }
580 }
581
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
585
    /// Registers `buffer` (and its project, if needed) so its edits feed the
    /// change-event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
595
596 fn get_or_init_project(
597 &mut self,
598 project: &Entity<Project>,
599 cx: &mut Context<Self>,
600 ) -> &mut ProjectState {
601 let entity_id = project.entity_id();
602 let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
603 self.projects
604 .entry(entity_id)
605 .or_insert_with(|| ProjectState {
606 context: {
607 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
608 cx.subscribe(
609 &related_excerpt_store,
610 move |this, _, event, _| match event {
611 RelatedExcerptStoreEvent::StartedRefresh => {
612 if let Some(debug_tx) = this.debug_tx.clone() {
613 debug_tx
614 .unbounded_send(DebugEvent::ContextRetrievalStarted(
615 ContextRetrievalStartedDebugEvent {
616 project_entity_id: entity_id,
617 timestamp: Instant::now(),
618 search_prompt: String::new(),
619 },
620 ))
621 .ok();
622 }
623 }
624 RelatedExcerptStoreEvent::FinishedRefresh {
625 cache_hit_count,
626 cache_miss_count,
627 mean_definition_latency,
628 max_definition_latency,
629 } => {
630 if let Some(debug_tx) = this.debug_tx.clone() {
631 debug_tx
632 .unbounded_send(DebugEvent::ContextRetrievalFinished(
633 ContextRetrievalFinishedDebugEvent {
634 project_entity_id: entity_id,
635 timestamp: Instant::now(),
636 metadata: vec![
637 (
638 "Cache Hits",
639 format!(
640 "{}/{}",
641 cache_hit_count,
642 cache_hit_count + cache_miss_count
643 )
644 .into(),
645 ),
646 (
647 "Max LSP Time",
648 format!(
649 "{} ms",
650 max_definition_latency.as_millis()
651 )
652 .into(),
653 ),
654 (
655 "Mean LSP Time",
656 format!(
657 "{} ms",
658 mean_definition_latency.as_millis()
659 )
660 .into(),
661 ),
662 ],
663 },
664 ))
665 .ok();
666 }
667 if let Some(project_state) = this.projects.get(&entity_id) {
668 project_state.context_updates_tx.send_blocking(()).ok();
669 }
670 }
671 },
672 )
673 .detach();
674 related_excerpt_store
675 },
676 events: VecDeque::new(),
677 last_event: None,
678 recent_paths: VecDeque::new(),
679 context_updates_rx,
680 context_updates_tx,
681 registered_buffers: HashMap::default(),
682 current_prediction: None,
683 cancelled_predictions: HashSet::default(),
684 pending_predictions: ArrayVec::new(),
685 next_pending_prediction_id: 0,
686 last_prediction_refresh: None,
687 license_detection_watchers: HashMap::default(),
688 _subscription: cx.subscribe(&project, Self::handle_project_event),
689 })
690 }
691
692 pub fn project_context_updates(
693 &self,
694 project: &Entity<Project>,
695 ) -> Option<smol::channel::Receiver<()>> {
696 let project_state = self.projects.get(&project.entity_id())?;
697 Some(project_state.context_updates_rx.clone())
698 }
699
    /// Reacts to project events: maintains the recent-paths MRU list and, for
    /// Zeta2 users, refreshes predictions when diagnostics change.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front of the MRU list, de-duplicating
                    // any earlier occurrence.
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
732
    /// Ensures `buffer` is registered in `project_state`: installs a license
    /// watcher for the buffer's worktree and subscriptions tracking edits and
    /// buffer release. Returns the (possibly pre-existing) registration.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license watcher per worktree, removing it when the
        // worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Feed buffer edits into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
794
    /// Records a buffer change in the project's event history, coalescing it
    /// into the previous event when it continues the same buffer within
    /// `CHANGE_GROUPING_LINE_SPAN` rows of the last edit.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last report; bail early.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End position of the last edit in this change, used to decide whether
        // the next change is "nearby" enough to coalesce.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change picks up exactly where the
            // pending event left off, in the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the edits are within `CHANGE_GROUPING_LINE_SPAN` rows.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Evict the oldest finalized event to stay within `EVENT_COUNT_MAX`
        // (counting the pending `last_event` being finalized below).
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
857
858 fn current_prediction_for_buffer(
859 &self,
860 buffer: &Entity<Buffer>,
861 project: &Entity<Project>,
862 cx: &App,
863 ) -> Option<BufferEditPrediction<'_>> {
864 let project_state = self.projects.get(&project.entity_id())?;
865
866 let CurrentEditPrediction {
867 requested_by,
868 prediction,
869 ..
870 } = project_state.current_prediction.as_ref()?;
871
872 if prediction.targets_buffer(buffer.read(cx)) {
873 Some(BufferEditPrediction::Local { prediction })
874 } else {
875 let show_jump = match requested_by {
876 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
877 requested_by_buffer_id == &buffer.entity_id()
878 }
879 PredictionRequestedBy::DiagnosticsUpdate => true,
880 };
881
882 if show_jump {
883 Some(BufferEditPrediction::Jump { prediction })
884 } else {
885 None
886 }
887 }
888 }
889
    /// Reports the current prediction as accepted to the Zed LLM service and
    /// cancels any pending requests. No-op for third-party models.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Acceptance telemetry only applies to the Zed-hosted models.
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting invalidates anything still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint for testing.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
942
    /// Background loop that batches rejection reports and flushes them to the
    /// `/predict_edits/reject` endpoint. Runs until the sender side is dropped.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        // Rejections accumulate here; entries are only drained after a
        // successful flush, so failures are retried on the next round.
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Keep accumulating until the batch is half-full, another item is
            // immediately available, or the debounce timer fires.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush only the most recent `flush_count` items, capped at the
            // per-request maximum.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Only drop the flushed items on success; on failure they stay
            // queued for the next attempt.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1000
1001 fn reject_current_prediction(
1002 &mut self,
1003 reason: EditPredictionRejectReason,
1004 project: &Entity<Project>,
1005 ) {
1006 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1007 project_state.pending_predictions.clear();
1008 if let Some(prediction) = project_state.current_prediction.take() {
1009 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1010 }
1011 };
1012 }
1013
1014 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
1015 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1016 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
1017 if !current_prediction.was_shown {
1018 current_prediction.was_shown = true;
1019 self.shown_predictions
1020 .push_front(current_prediction.prediction.clone());
1021 if self.shown_predictions.len() > 50 {
1022 let completion = self.shown_predictions.pop_back().unwrap();
1023 self.rated_predictions.remove(&completion.id);
1024 }
1025 }
1026 }
1027 }
1028 }
1029
1030 fn reject_prediction(
1031 &mut self,
1032 prediction_id: EditPredictionId,
1033 reason: EditPredictionRejectReason,
1034 was_shown: bool,
1035 ) {
1036 match self.edit_prediction_model {
1037 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
1038 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1039 }
1040
1041 self.reject_predictions_tx
1042 .unbounded_send(EditPredictionRejection {
1043 request_id: prediction_id.to_string(),
1044 reason,
1045 was_shown,
1046 })
1047 .log_err();
1048 }
1049
1050 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1051 self.projects
1052 .get(&project.entity_id())
1053 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1054 }
1055
    /// Queues a (throttled) prediction refresh triggered by activity in
    /// `buffer` at `position`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Tag the result with the originating buffer so the UI can
                // attribute the prediction.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1091
    /// Queues a prediction refresh triggered by a diagnostics update: opens
    /// the active buffer, finds the next diagnostic location, and requests a
    /// prediction there. Buffer-driven predictions take priority.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Resolve the project's active entry to an open buffer.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-driven prediction may have arrived while we
                        // awaited; if so, reject this one as lower priority.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1172
    /// Minimum delay between consecutive prediction refreshes for the same
    /// throttle entity (see `queue_prediction_refresh`). Zero under `cfg(test)`
    /// so tests don't have to wait out the throttle window.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1177
1178 fn queue_prediction_refresh(
1179 &mut self,
1180 project: Entity<Project>,
1181 throttle_entity: EntityId,
1182 cx: &mut Context<Self>,
1183 do_refresh: impl FnOnce(
1184 WeakEntity<Self>,
1185 &mut AsyncApp,
1186 )
1187 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1188 + 'static,
1189 ) {
1190 let project_state = self.get_or_init_project(&project, cx);
1191 let pending_prediction_id = project_state.next_pending_prediction_id;
1192 project_state.next_pending_prediction_id += 1;
1193 let last_request = project_state.last_prediction_refresh;
1194
1195 let task = cx.spawn(async move |this, cx| {
1196 if let Some((last_entity, last_timestamp)) = last_request
1197 && throttle_entity == last_entity
1198 && let Some(timeout) =
1199 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1200 {
1201 cx.background_executor().timer(timeout).await;
1202 }
1203
1204 // If this task was cancelled before the throttle timeout expired,
1205 // do not perform a request.
1206 let mut is_cancelled = true;
1207 this.update(cx, |this, cx| {
1208 let project_state = this.get_or_init_project(&project, cx);
1209 if !project_state
1210 .cancelled_predictions
1211 .remove(&pending_prediction_id)
1212 {
1213 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1214 is_cancelled = false;
1215 }
1216 })
1217 .ok();
1218 if is_cancelled {
1219 return None;
1220 }
1221
1222 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1223 let new_prediction_id = new_prediction_result
1224 .as_ref()
1225 .map(|(prediction, _)| prediction.id.clone());
1226
1227 // When a prediction completes, remove it from the pending list, and cancel
1228 // any pending predictions that were enqueued before it.
1229 this.update(cx, |this, cx| {
1230 let project_state = this.get_or_init_project(&project, cx);
1231
1232 let is_cancelled = project_state
1233 .cancelled_predictions
1234 .remove(&pending_prediction_id);
1235
1236 let new_current_prediction = if !is_cancelled
1237 && let Some((prediction_result, requested_by)) = new_prediction_result
1238 {
1239 match prediction_result.prediction {
1240 Ok(prediction) => {
1241 let new_prediction = CurrentEditPrediction {
1242 requested_by,
1243 prediction,
1244 was_shown: false,
1245 };
1246
1247 if let Some(current_prediction) =
1248 project_state.current_prediction.as_ref()
1249 {
1250 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1251 {
1252 this.reject_current_prediction(
1253 EditPredictionRejectReason::Replaced,
1254 &project,
1255 );
1256
1257 Some(new_prediction)
1258 } else {
1259 this.reject_prediction(
1260 new_prediction.prediction.id,
1261 EditPredictionRejectReason::CurrentPreferred,
1262 false,
1263 );
1264 None
1265 }
1266 } else {
1267 Some(new_prediction)
1268 }
1269 }
1270 Err(reject_reason) => {
1271 this.reject_prediction(prediction_result.id, reject_reason, false);
1272 None
1273 }
1274 }
1275 } else {
1276 None
1277 };
1278
1279 let project_state = this.get_or_init_project(&project, cx);
1280
1281 if let Some(new_prediction) = new_current_prediction {
1282 project_state.current_prediction = Some(new_prediction);
1283 }
1284
1285 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1286 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1287 if pending_prediction.id == pending_prediction_id {
1288 pending_predictions.remove(ix);
1289 for pending_prediction in pending_predictions.drain(0..ix) {
1290 project_state.cancel_pending_prediction(pending_prediction, cx)
1291 }
1292 break;
1293 }
1294 }
1295 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1296 cx.notify();
1297 })
1298 .ok();
1299
1300 new_prediction_id
1301 });
1302
1303 if project_state.pending_predictions.len() <= 1 {
1304 project_state.pending_predictions.push(PendingPrediction {
1305 id: pending_prediction_id,
1306 task,
1307 });
1308 } else if project_state.pending_predictions.len() == 2 {
1309 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1310 project_state.pending_predictions.push(PendingPrediction {
1311 id: pending_prediction_id,
1312 task,
1313 });
1314 project_state.cancel_pending_prediction(pending_prediction, cx);
1315 }
1316 }
1317
1318 pub fn request_prediction(
1319 &mut self,
1320 project: &Entity<Project>,
1321 active_buffer: &Entity<Buffer>,
1322 position: language::Anchor,
1323 trigger: PredictEditsRequestTrigger,
1324 cx: &mut Context<Self>,
1325 ) -> Task<Result<Option<EditPredictionResult>>> {
1326 self.request_prediction_internal(
1327 project.clone(),
1328 active_buffer.clone(),
1329 position,
1330 trigger,
1331 cx.has_flag::<Zeta2FeatureFlag>(),
1332 cx,
1333 )
1334 }
1335
    /// Dispatches a prediction request to the backend selected by
    /// `self.edit_prediction_model`. If the backend returns no prediction and
    /// `allow_jump` is true, retries once at the nearest diagnostic outside
    /// the already-searched range (the retry passes `allow_jump = false`, so
    /// at most one jump happens per request).
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor count as already
        // covered by the primary request.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let events = project_state.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        // Row window around the cursor, saturating at the top of the buffer.
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = if self.use_context {
            self.context_for_project(&project, cx).to_vec()
        } else {
            Vec::new()
        };

        // Every backend receives the same buffer/cursor/event information;
        // they differ in which extra context they accept.
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                related_files,
                trigger,
                cx,
            ),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &project_state.recent_paths,
                related_files,
                diagnostic_search_range.clone(),
                cx,
            ),
            EditPredictionModel::Mercury => self.mercury.request_prediction(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &project_state.recent_paths,
                related_files,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there is recent edit history to justify it.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1447
    /// Finds a diagnostic to jump to when the primary prediction request came
    /// back empty.
    ///
    /// First looks in the active buffer for the diagnostic closest to the
    /// cursor (by row distance) that falls outside
    /// `active_buffer_diagnostic_search_range`. If none exists, opens the
    /// other buffer with diagnostics whose path shares the most leading
    /// components with the active buffer's path and picks a diagnostic there.
    /// Returns `None` when no candidate is found anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // No candidate in the active buffer; fall back to other buffers
            // with diagnostics. `None` if the buffer has no backing file.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // NOTE(review): `min_by_key` on severity picks the entry whose
                // severity value orders lowest — presumably the most severe
                // diagnostic; confirm against the severity type's ordering.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1526
    /// Sends a raw OpenAI-style request to the predict-edits endpoint (or the
    /// `PREDICT_EDITS_URL` override), returning the parsed response together
    /// with any usage info reported by the server.
    ///
    /// With the `eval-support` feature, responses are cached keyed on a hash
    /// of the URL and pretty-printed request body, so repeated eval runs skip
    /// the network round-trip.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Consult the eval cache before issuing a network request; a cache
        // hit short-circuits with no usage info.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the fresh response for future eval runs.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1584
    /// Post-processes an API response on behalf of the store.
    ///
    /// On success, forwards any usage accounting to the user store and returns
    /// the payload. On `ZedUpdateRequiredError`, flags the store as requiring
    /// an update and shows a notification with an update link before
    /// propagating the error.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    // Best-effort: the store may already be dropped.
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1628
    /// Sends an authenticated JSON POST request built by `build`,
    /// deserializing the response body into `Res`.
    ///
    /// Retries exactly once with a freshly refreshed LLM token when the
    /// server reports the current token as expired. Fails with
    /// `ZedUpdateRequiredError` when the server's minimum-required-version
    /// header exceeds `app_version`.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // At most two iterations: the initial attempt plus one retry with a
        // refreshed token (guarded by `did_retry`).
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server advertises the oldest client version it supports.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1692
1693 pub fn refresh_context(
1694 &mut self,
1695 project: &Entity<Project>,
1696 buffer: &Entity<language::Buffer>,
1697 cursor_position: language::Anchor,
1698 cx: &mut Context<Self>,
1699 ) {
1700 if self.use_context {
1701 self.get_or_init_project(project, cx)
1702 .context
1703 .update(cx, |store, cx| {
1704 store.refresh(buffer.clone(), cursor_position, cx);
1705 });
1706 }
1707 }
1708
1709 fn is_file_open_source(
1710 &self,
1711 project: &Entity<Project>,
1712 file: &Arc<dyn File>,
1713 cx: &App,
1714 ) -> bool {
1715 if !file.is_local() || file.is_private() {
1716 return false;
1717 }
1718 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1719 return false;
1720 };
1721 project_state
1722 .license_detection_watchers
1723 .get(&file.worktree_id(cx))
1724 .as_ref()
1725 .is_some_and(|watcher| watcher.is_project_open_source())
1726 }
1727
1728 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1729 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1730 }
1731
1732 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
1733 if !self.data_collection_choice.is_enabled() {
1734 return false;
1735 }
1736 events.iter().all(|event| {
1737 matches!(
1738 event.as_ref(),
1739 Event::BufferChange {
1740 in_open_source_repo: true,
1741 ..
1742 }
1743 )
1744 })
1745 }
1746
1747 fn load_data_collection_choice() -> DataCollectionChoice {
1748 let choice = KEY_VALUE_STORE
1749 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1750 .log_err()
1751 .flatten();
1752
1753 match choice.as_deref() {
1754 Some("true") => DataCollectionChoice::Enabled,
1755 Some("false") => DataCollectionChoice::Disabled,
1756 Some(_) => {
1757 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1758 DataCollectionChoice::NotAnswered
1759 }
1760 None => DataCollectionChoice::NotAnswered,
1761 }
1762 }
1763
1764 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1765 self.data_collection_choice = self.data_collection_choice.toggle();
1766 let new_choice = self.data_collection_choice;
1767 db::write_and_log(cx, move || {
1768 KEY_VALUE_STORE.write_kvp(
1769 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1770 new_choice.is_enabled().to_string(),
1771 )
1772 });
1773 }
1774
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
1786
    /// Records the user's rating for a prediction: marks it as rated, emits a
    /// telemetry event containing the prediction's inputs, unified output
    /// diff, and free-form feedback, then flushes telemetry immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away so the rating isn't lost if the app exits soon.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1805
1806 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1807 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1808 && all_language_settings(None, cx).edit_predictions.use_context;
1809 }
1810}
1811
/// Error raised when the server reports a minimum required Zed version newer
/// than the running one; surfaced to the user as an update prompt.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The lowest Zed version the server will accept.
    minimum_version: Version,
}
1819
/// Cache key for eval-support entries: the entry kind plus a hash of the
/// request URL and serialized request body.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
1822
/// Distinguishes which stage of the pipeline produced an [`EvalCache`] entry.
///
/// Derives `Eq` and `Hash` in addition to the original `PartialEq`: this type
/// is one half of [`EvalCacheKey`], and a fieldless enum supports both derives
/// at no cost, allowing keys to be used in hash-based caches.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1830
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Renders the entry kind as its lowercase name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(tag)
    }
}
1841
/// Storage backend for caching LLM request/response pairs during evals.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (the serialized response) under `key`; `input` is the
    /// serialized request that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1847
/// The user's choice about contributing edit-prediction data, persisted in
/// the key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not answered the data-collection prompt.
    NotAnswered,
    /// The user explicitly opted in.
    Enabled,
    /// The user explicitly opted out.
    Disabled,
}
1854
1855impl DataCollectionChoice {
1856 pub fn is_enabled(self) -> bool {
1857 match self {
1858 Self::Enabled => true,
1859 Self::NotAnswered | Self::Disabled => false,
1860 }
1861 }
1862
1863 pub fn is_answered(self) -> bool {
1864 match self {
1865 Self::Enabled | Self::Disabled => true,
1866 Self::NotAnswered => false,
1867 }
1868 }
1869
1870 #[must_use]
1871 pub fn toggle(&self) -> DataCollectionChoice {
1872 match self {
1873 Self::Enabled => Self::Disabled,
1874 Self::Disabled => Self::Enabled,
1875 Self::NotAnswered => Self::Enabled,
1876 }
1877 }
1878}
1879
1880impl From<bool> for DataCollectionChoice {
1881 fn from(value: bool) -> Self {
1882 match value {
1883 true => DataCollectionChoice::Enabled,
1884 false => DataCollectionChoice::Disabled,
1885 }
1886 }
1887}
1888
1889struct ZedPredictUpsell;
1890
1891impl Dismissable for ZedPredictUpsell {
1892 const KEY: &'static str = "dismissed-edit-predict-upsell";
1893
1894 fn dismissed() -> bool {
1895 // To make this backwards compatible with older versions of Zed, we
1896 // check if the user has seen the previous Edit Prediction Onboarding
1897 // before, by checking the data collection choice which was written to
1898 // the database once the user clicked on "Accept and Enable"
1899 if KEY_VALUE_STORE
1900 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1901 .log_err()
1902 .is_some_and(|s| s.is_some())
1903 {
1904 return true;
1905 }
1906
1907 KEY_VALUE_STORE
1908 .read_kvp(Self::KEY)
1909 .log_err()
1910 .is_some_and(|s| s.is_some())
1911 }
1912}
1913
/// The upsell modal should be shown until the user dismisses it (or has
/// completed the legacy onboarding flow — see `ZedPredictUpsell::dismissed`).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
1917
/// Registers workspace actions for the edit-prediction onboarding flow.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Opens (or toggles) the Zed Predict onboarding modal.
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resets onboarding by clearing the configured edit-prediction
        // provider in the settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}