1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
12use collections::{HashMap, HashSet};
13use db::kvp::{Dismissable, KEY_VALUE_STORE};
14use edit_prediction_context::EditPredictionExcerptOptions;
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::{
20 mpsc::{self, UnboundedReceiver},
21 oneshot,
22 },
23 select_biased,
24};
25use gpui::BackgroundExecutor;
26use gpui::{
27 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
28 http_client::{self, AsyncBody, Method},
29 prelude::*,
30};
31use language::language_settings::all_language_settings;
32use language::{Anchor, Buffer, File, Point, ToPoint};
33use language::{BufferSnapshot, OffsetRangeExt};
34use language_model::{LlmApiToken, RefreshLlmTokenListener};
35use project::{Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
40use std::collections::{VecDeque, hash_map};
41use workspace::Workspace;
42
43use std::ops::Range;
44use std::path::Path;
45use std::rc::Rc;
46use std::str::FromStr as _;
47use std::sync::{Arc, LazyLock};
48use std::time::{Duration, Instant};
49use std::{env, mem};
50use thiserror::Error;
51use util::{RangeExt as _, ResultExt as _};
52use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
53
54mod license_detection;
55mod onboarding_modal;
56mod prediction;
57pub mod sweep_ai;
58pub mod udiff;
59mod xml_edits;
60mod zed_edit_prediction_delegate;
61pub mod zeta1;
62pub mod zeta2;
63
64#[cfg(test)]
65mod edit_prediction_tests;
66
67use crate::license_detection::LicenseDetectionWatcher;
68use crate::onboarding_modal::ZedPredictModal;
69pub use crate::prediction::EditPrediction;
70pub use crate::prediction::EditPredictionId;
71pub use crate::prediction::EditPredictionInputs;
72use crate::prediction::EditPredictionResult;
73pub use crate::sweep_ai::SweepAi;
74pub use telemetry_events::EditPredictionRating;
75pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
76
// Workspace actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
86
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits within this many lines of each other are coalesced into
/// a single history event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key under which the user's data-collection choice persists.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long rejected predictions are batched before being reported upstream.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
92
/// Feature flag gating the Sweep AI edit-prediction backend.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
98
/// Prediction options used until a caller overrides them via `set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    prompt_format: PromptFormat::DEFAULT,
};
108
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value, routing
/// prediction requests to a local Ollama server.
static USE_OLLAMA: LazyLock<bool> = LazyLock::new(|| match env::var("ZED_ZETA2_OLLAMA") {
    Ok(var) => !var.is_empty(),
    Err(_) => false,
});

/// Model id sent with prediction requests; overridable via `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    let var = env::var("ZED_ZETA2_MODEL");
    let id: &str = match var.as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    };
    id.to_string()
});

/// Optional override of the predict-edits endpoint: either an explicit
/// `ZED_PREDICT_EDITS_URL`, or the local Ollama endpoint when enabled.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    if let Ok(url) = env::var("ZED_PREDICT_EDITS_URL") {
        return Some(url);
    }
    if *USE_OLLAMA {
        Some("http://localhost:11434/v1/chat/completions".into())
    } else {
        None
    }
});
130
/// Feature flag gating the zeta2 edit-prediction backend; always enabled for
/// staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
140
/// Global handle to the app-wide [`EditPredictionStore`] entity.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
145
/// App-wide store coordinating edit predictions: tracks per-project state,
/// requests predictions from the configured backend, and reports accepted and
/// rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    // When set, debug events are streamed to the receiver from `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections pushed here are batched and reported by
    // `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // Most recently shown predictions, newest at the front (capped at 50 in
    // `did_show_current_prediction`).
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
165
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
173
/// Tuning knobs for zeta prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionExcerptOptions,
    pub max_prompt_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
180
/// Events streamed to debug subscribers (see `EditPredictionStore::debug_info`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionRequested(EditPredictionRequestedDebugEvent),
}

/// Emitted when a context refresh begins for a project.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when a context refresh finishes, with cache/latency metadata.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // (label, value) pairs intended for display.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request is sent; the model response arrives
/// later on `response_rx` together with its round-trip duration.
#[derive(Debug)]
pub struct EditPredictionRequestedDebugEvent {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // The locally rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
213
/// Per-project prediction state.
struct ProjectState {
    // Finalized history events, oldest first (length capped via
    // EVENT_COUNT_MAX in `report_changes_for_buffer`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // The still-open event that subsequent nearby edits are coalesced into.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Notifies listeners whenever the related-excerpt context is refreshed.
    context_updates_tx: smol::channel::Sender<()>,
    context_updates_rx: smol::channel::Receiver<()>,
    // Entity and time of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
230
231impl ProjectState {
232 pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
233 self.events
234 .iter()
235 .cloned()
236 .chain(
237 self.last_event
238 .as_ref()
239 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
240 )
241 .collect()
242 }
243
244 fn cancel_pending_prediction(
245 &mut self,
246 pending_prediction: PendingPrediction,
247 cx: &mut Context<EditPredictionStore>,
248 ) {
249 self.cancelled_predictions.insert(pending_prediction.id);
250
251 cx.spawn(async move |this, cx| {
252 let Some(prediction_id) = pending_prediction.task.await else {
253 return;
254 };
255
256 this.update(cx, |this, _cx| {
257 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
258 })
259 .ok();
260 })
261 .detach()
262 }
263}
264
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Whether the prediction has actually been displayed yet.
    pub was_shown: bool,
}
271
impl CurrentEditPrediction {
    /// Whether `self` (a newly arrived prediction) should replace
    /// `old_prediction` as the one offered to the user.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // contents, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, always replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // For single-edit predictions, only replace when the new edit
            // merely extends the old edit's text at the same range.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
310
311#[derive(Debug, Clone)]
312enum PredictionRequestedBy {
313 DiagnosticsUpdate,
314 Buffer(EntityId),
315}
316
317impl PredictionRequestedBy {
318 pub fn buffer_id(&self) -> Option<EntityId> {
319 match self {
320 PredictionRequestedBy::DiagnosticsUpdate => None,
321 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
322 }
323 }
324}
325
/// A prediction request that has been queued but not yet resolved.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    // Resolves to the prediction's id, or `None` when no prediction resulted.
    task: Task<Option<EditPredictionId>>,
}
331
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits a different buffer; surface a jump to it.
    Jump { prediction: &'a EditPrediction },
}

#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
350
/// Bookkeeping for a buffer whose edits are being tracked.
struct RegisteredBuffer {
    // Snapshot as of the last processed edit; diffed against newer snapshots.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The still-open history event; nearby edits are coalesced into it until it
/// is finalized via `LastEvent::finalize`.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // Anchor at the end of the most recent edit, used for proximity grouping.
    end_edit_anchor: Option<Anchor>,
}
361
362impl LastEvent {
363 pub fn finalize(
364 &self,
365 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
366 cx: &App,
367 ) -> Option<Arc<predict_edits_v3::Event>> {
368 let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
369 let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
370
371 let file = self.new_snapshot.file();
372 let old_file = self.old_snapshot.file();
373
374 let in_open_source_repo = [file, old_file].iter().all(|file| {
375 file.is_some_and(|file| {
376 license_detection_watchers
377 .get(&file.worktree_id(cx))
378 .is_some_and(|watcher| watcher.is_project_open_source())
379 })
380 });
381
382 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
383
384 if path == old_path && diff.is_empty() {
385 None
386 } else {
387 Some(Arc::new(predict_edits_v3::Event::BufferChange {
388 old_path,
389 path,
390 diff,
391 in_open_source_repo,
392 // TODO: Actually detect if this edit was predicted or not
393 predicted: false,
394 }))
395 }
396 }
397}
398
399fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
400 if let Some(file) = snapshot.file() {
401 file.full_path(cx).into()
402 } else {
403 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
404 }
405}
406
407impl EditPredictionStore {
408 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
409 cx.try_global::<EditPredictionStoreGlobal>()
410 .map(|global| global.0.clone())
411 }
412
413 pub fn global(
414 client: &Arc<Client>,
415 user_store: &Entity<UserStore>,
416 cx: &mut App,
417 ) -> Entity<Self> {
418 cx.try_global::<EditPredictionStoreGlobal>()
419 .map(|global| global.0.clone())
420 .unwrap_or_else(|| {
421 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
422 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
423 ep_store
424 })
425 }
426
    /// Creates the store: loads the persisted data-collection choice, spawns
    /// the background loop that batches rejected-prediction reports, and
    /// wires up LLM-token refresh, feature-flag, and settings observers.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections sent over this channel are batched and reported to the
        // server by `handle_rejected_predictions`.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Re-fetch the LLM token whenever the listener reports a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        };

        // Context retrieval depends on feature flags and settings, so
        // configure it now and again whenever either changes.
        this.configure_context_retrieval(cx);
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
498
    /// Selects which backend future prediction requests will use.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
502
503 pub fn has_sweep_api_token(&self) -> bool {
504 self.sweep_ai
505 .api_token
506 .clone()
507 .now_or_never()
508 .flatten()
509 .is_some()
510 }
511
    /// Installs a cache used to record and replay requests during evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
516
517 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
518 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
519 self.debug_tx = Some(debug_watch_tx);
520 debug_watch_rx
521 }
522
    /// The options currently used for prediction requests.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
526
    /// Replaces the options used for future prediction requests.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
530
    /// Toggles whether retrieved context is included in prediction requests.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
534
535 pub fn clear_history(&mut self) {
536 for project_state in self.projects.values_mut() {
537 project_state.events.clear();
538 }
539 }
540
541 pub fn context_for_project<'a>(
542 &'a self,
543 project: &Entity<Project>,
544 cx: &'a App,
545 ) -> &'a [RelatedFile] {
546 self.projects
547 .get(&project.entity_id())
548 .map(|project| project.context.read(cx).related_files())
549 .unwrap_or(&[])
550 }
551
552 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
553 if self.edit_prediction_model == EditPredictionModel::Zeta2 {
554 self.user_store.read(cx).edit_prediction_usage()
555 } else {
556 None
557 }
558 }
559
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
563
    /// Starts tracking `buffer` within `project`, initializing the project's
    /// state if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
573
    /// Returns the state for `project`, creating it on first use — including
    /// the related-excerpt store and its event subscription, and the
    /// subscription to project events.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Forward refresh start/finish to debug subscribers, and
                    // announce finished refreshes on the context-updates
                    // channel.
                    cx.subscribe(
                        &related_excerpt_store,
                        move |this, _, event, _| match event {
                            RelatedExcerptStoreEvent::StartedRefresh => {
                                if let Some(debug_tx) = this.debug_tx.clone() {
                                    debug_tx
                                        .unbounded_send(DebugEvent::ContextRetrievalStarted(
                                            ContextRetrievalStartedDebugEvent {
                                                project_entity_id: entity_id,
                                                timestamp: Instant::now(),
                                                search_prompt: String::new(),
                                            },
                                        ))
                                        .ok();
                                }
                            }
                            RelatedExcerptStoreEvent::FinishedRefresh {
                                cache_hit_count,
                                cache_miss_count,
                                mean_definition_latency,
                                max_definition_latency,
                            } => {
                                if let Some(debug_tx) = this.debug_tx.clone() {
                                    debug_tx
                                        .unbounded_send(DebugEvent::ContextRetrievalFinished(
                                            ContextRetrievalFinishedDebugEvent {
                                                project_entity_id: entity_id,
                                                timestamp: Instant::now(),
                                                metadata: vec![
                                                    (
                                                        "Cache Hits",
                                                        format!(
                                                            "{}/{}",
                                                            cache_hit_count,
                                                            cache_hit_count + cache_miss_count
                                                        )
                                                        .into(),
                                                    ),
                                                    (
                                                        "Max LSP Time",
                                                        format!(
                                                            "{} ms",
                                                            max_definition_latency.as_millis()
                                                        )
                                                        .into(),
                                                    ),
                                                    (
                                                        "Mean LSP Time",
                                                        format!(
                                                            "{} ms",
                                                            mean_definition_latency.as_millis()
                                                        )
                                                        .into(),
                                                    ),
                                                ],
                                            },
                                        ))
                                        .ok();
                                }
                                if let Some(project_state) = this.projects.get(&entity_id) {
                                    project_state.context_updates_tx.send_blocking(()).ok();
                                }
                            }
                        },
                    )
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                context_updates_rx,
                context_updates_tx,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
669
670 pub fn project_context_updates(
671 &self,
672 project: &Entity<Project>,
673 ) -> Option<smol::channel::Receiver<()>> {
674 let project_state = self.projects.get(&project.entity_id())?;
675 Some(project_state.context_updates_rx.clone())
676 }
677
    /// Reacts to project events: maintains the recent-paths MRU list and, for
    /// zeta2 users, refreshes predictions when diagnostics change.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front of the MRU list.
                    if let Some(ix) = project_state
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        project_state.recent_paths.remove(ix);
                    }
                    project_state.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
710
    /// Ensures `buffer` is tracked in `project_state`: subscribes to its
    /// edits and release, and lazily sets up a license-detection watcher for
    /// the buffer's worktree.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Create the worktree's license watcher on first sight; it is removed
        // again when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop tracking state when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
772
    /// Records a buffer edit into the project's history, coalescing it into
    /// the still-open last event when it is a nearby follow-up edit to the
    /// same buffer.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            // No new edits since the last report.
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Where the most recent edit ended, used for proximity grouping below.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce when this edit directly continues the open event's
            // snapshot chain...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and lands within CHANGE_GROUPING_LINE_SPAN lines of it.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Cap the history before finalizing the previous open event into it.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
835
    /// Returns the project's current prediction as seen from `buffer`: a
    /// `Local` prediction when it edits this buffer, a `Jump` when it edits
    /// another buffer but should be surfaced here, or `None`.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only offer a jump from the buffer that requested the
            // prediction, or from anywhere for diagnostics-driven requests.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
867
    /// Takes the project's current prediction as accepted: cancels pending
    /// requests and reports the acceptance to the server (zeta models only).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            // Sweep predictions are not reported to the zed.dev endpoint.
            EditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the endpoint (e.g. for
            // local testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
920
    /// Background loop that batches rejection reports and sends them to the
    /// server, debounced by `REJECT_REQUEST_DEBOUNCE` and capped per request.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        // Rejections not yet successfully reported; drained only on success.
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Keep accumulating until the batch is half the request cap,
            // another item is already waiting, or the debounce timer fires.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush at most the per-request cap, preferring the newest items.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Drop the flushed items only on success; on failure they remain
            // queued and are retried with the next flush.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
978
    /// Discards the project's current prediction and any pending ones,
    /// reporting the rejection with the given reason.
    fn reject_current_prediction(
        &mut self,
        reason: EditPredictionRejectReason,
        project: &Entity<Project>,
    ) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.pending_predictions.clear();
            if let Some(prediction) = project_state.current_prediction.take() {
                self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
            }
        };
    }
991
992 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
993 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
994 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
995 if !current_prediction.was_shown {
996 current_prediction.was_shown = true;
997 self.shown_predictions
998 .push_front(current_prediction.prediction.clone());
999 if self.shown_predictions.len() > 50 {
1000 let completion = self.shown_predictions.pop_back().unwrap();
1001 self.rated_predictions.remove(&completion.id);
1002 }
1003 }
1004 }
1005 }
1006 }
1007
    /// Queues a rejection report for batching (zeta models only).
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
    ) {
        match self.edit_prediction_model {
            EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
            // Sweep predictions are not reported to the zed.dev endpoint.
            EditPredictionModel::Sweep => return,
        }

        self.reject_predictions_tx
            .unbounded_send(EditPredictionRejection {
                request_id: prediction_id.to_string(),
                reason,
                was_shown,
            })
            .log_err();
    }
1027
1028 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1029 self.projects
1030 .get(&project.entity_id())
1031 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1032 }
1033
    /// Queues a (throttled) prediction refresh triggered by activity in
    /// `buffer` at `position`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Tag the result with the originating buffer so jumps can be
                // routed back to it.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1069
    /// Queues a prediction refresh in response to a diagnostics update,
    /// targeting the next diagnostic location reachable from the active
    /// buffer. Does nothing when a prediction is already present.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Open the project's active entry to search diagnostics from.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find the diagnostic to jump to; bail when there is none.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-requested prediction may have arrived in
                        // the meantime; prefer it and reject this one.
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1150
    /// Minimum delay between consecutive prediction refreshes for the same
    /// entity (enforced in `queue_prediction_refresh`). Zero in tests so they
    /// don't have to wait out the throttle window.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1155
1156 fn queue_prediction_refresh(
1157 &mut self,
1158 project: Entity<Project>,
1159 throttle_entity: EntityId,
1160 cx: &mut Context<Self>,
1161 do_refresh: impl FnOnce(
1162 WeakEntity<Self>,
1163 &mut AsyncApp,
1164 )
1165 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1166 + 'static,
1167 ) {
1168 let project_state = self.get_or_init_project(&project, cx);
1169 let pending_prediction_id = project_state.next_pending_prediction_id;
1170 project_state.next_pending_prediction_id += 1;
1171 let last_request = project_state.last_prediction_refresh;
1172
1173 let task = cx.spawn(async move |this, cx| {
1174 if let Some((last_entity, last_timestamp)) = last_request
1175 && throttle_entity == last_entity
1176 && let Some(timeout) =
1177 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1178 {
1179 cx.background_executor().timer(timeout).await;
1180 }
1181
1182 // If this task was cancelled before the throttle timeout expired,
1183 // do not perform a request.
1184 let mut is_cancelled = true;
1185 this.update(cx, |this, cx| {
1186 let project_state = this.get_or_init_project(&project, cx);
1187 if !project_state
1188 .cancelled_predictions
1189 .remove(&pending_prediction_id)
1190 {
1191 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1192 is_cancelled = false;
1193 }
1194 })
1195 .ok();
1196 if is_cancelled {
1197 return None;
1198 }
1199
1200 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1201 let new_prediction_id = new_prediction_result
1202 .as_ref()
1203 .map(|(prediction, _)| prediction.id.clone());
1204
1205 // When a prediction completes, remove it from the pending list, and cancel
1206 // any pending predictions that were enqueued before it.
1207 this.update(cx, |this, cx| {
1208 let project_state = this.get_or_init_project(&project, cx);
1209
1210 let is_cancelled = project_state
1211 .cancelled_predictions
1212 .remove(&pending_prediction_id);
1213
1214 let new_current_prediction = if !is_cancelled
1215 && let Some((prediction_result, requested_by)) = new_prediction_result
1216 {
1217 match prediction_result.prediction {
1218 Ok(prediction) => {
1219 let new_prediction = CurrentEditPrediction {
1220 requested_by,
1221 prediction,
1222 was_shown: false,
1223 };
1224
1225 if let Some(current_prediction) =
1226 project_state.current_prediction.as_ref()
1227 {
1228 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1229 {
1230 this.reject_current_prediction(
1231 EditPredictionRejectReason::Replaced,
1232 &project,
1233 );
1234
1235 Some(new_prediction)
1236 } else {
1237 this.reject_prediction(
1238 new_prediction.prediction.id,
1239 EditPredictionRejectReason::CurrentPreferred,
1240 false,
1241 );
1242 None
1243 }
1244 } else {
1245 Some(new_prediction)
1246 }
1247 }
1248 Err(reject_reason) => {
1249 this.reject_prediction(prediction_result.id, reject_reason, false);
1250 None
1251 }
1252 }
1253 } else {
1254 None
1255 };
1256
1257 let project_state = this.get_or_init_project(&project, cx);
1258
1259 if let Some(new_prediction) = new_current_prediction {
1260 project_state.current_prediction = Some(new_prediction);
1261 }
1262
1263 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1264 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1265 if pending_prediction.id == pending_prediction_id {
1266 pending_predictions.remove(ix);
1267 for pending_prediction in pending_predictions.drain(0..ix) {
1268 project_state.cancel_pending_prediction(pending_prediction, cx)
1269 }
1270 break;
1271 }
1272 }
1273 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1274 cx.notify();
1275 })
1276 .ok();
1277
1278 new_prediction_id
1279 });
1280
1281 if project_state.pending_predictions.len() <= 1 {
1282 project_state.pending_predictions.push(PendingPrediction {
1283 id: pending_prediction_id,
1284 task,
1285 });
1286 } else if project_state.pending_predictions.len() == 2 {
1287 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1288 project_state.pending_predictions.push(PendingPrediction {
1289 id: pending_prediction_id,
1290 task,
1291 });
1292 project_state.cancel_pending_prediction(pending_prediction, cx);
1293 }
1294 }
1295
1296 pub fn request_prediction(
1297 &mut self,
1298 project: &Entity<Project>,
1299 active_buffer: &Entity<Buffer>,
1300 position: language::Anchor,
1301 trigger: PredictEditsRequestTrigger,
1302 cx: &mut Context<Self>,
1303 ) -> Task<Result<Option<EditPredictionResult>>> {
1304 self.request_prediction_internal(
1305 project.clone(),
1306 active_buffer.clone(),
1307 position,
1308 trigger,
1309 cx.has_flag::<Zeta2FeatureFlag>(),
1310 cx,
1311 )
1312 }
1313
    /// Dispatches a prediction request to the configured model backend
    /// (zeta1, zeta2, or Sweep).
    ///
    /// When the backend returns no prediction and `allow_jump` is true, falls
    /// back to the nearest diagnostic outside the ±`DIAGNOSTIC_LINES_RANGE`
    /// rows around the cursor and retries once there, passing
    /// `allow_jump = false` to prevent chained jumps. Jump fallback only
    /// happens when the project has recorded edit events.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let events = project_state.events(cx);
        let has_events = !events.is_empty();

        // Window of rows around the cursor considered "already covered" by
        // this request when later searching for a diagnostic to jump to.
        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        // Related-file context is only gathered when context retrieval is on.
        let related_files = if self.use_context {
            self.context_for_project(&project, cx).to_vec()
        } else {
            Vec::new()
        };

        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                related_files,
                trigger,
                cx,
            ),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &project_state.recent_paths,
                related_files,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry once at the diagnostic location, with jumping
                    // disabled to avoid recursing further.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1414
    /// Finds the next diagnostic location to jump to for a follow-up
    /// prediction request.
    ///
    /// First looks in the active buffer: among diagnostic groups whose
    /// primary range does NOT overlap `active_buffer_diagnostic_search_range`
    /// (the region already covered by the previous request), picks the one
    /// closest to the cursor by row distance. When none qualifies, falls back
    /// to another buffer with diagnostics — choosing the path that shares the
    /// most leading components with the active buffer's path — and returns
    /// the entry with the minimum severity value there (lower severity values
    /// presumably rank as more severe — TODO confirm against the diagnostic
    /// severity ordering).
    ///
    /// Returns `Ok(None)` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1493
    /// Sends a raw OpenAI-format request to the predict-edits endpoint.
    ///
    /// The endpoint URL comes from the `PREDICT_EDITS_URL` override when set,
    /// otherwise it is built from the client's LLM service URL. With the
    /// `eval-support` feature enabled, responses are cached keyed by a hash
    /// of the URL plus the pretty-printed request body; a cache hit
    /// short-circuits the network call and returns no usage info.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            // Cache key: hash of target URL + serialized request body.
            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the fresh response so later eval runs can replay it.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1551
    /// Applies the side effects of an API response and unwraps its payload.
    ///
    /// On success, forwards any usage info to the user store and returns the
    /// data. On a [`ZedUpdateRequiredError`], sets `update_required` and
    /// shows an app notification with an update link before propagating the
    /// error unchanged.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1595
    /// Sends an authenticated JSON POST to the LLM service and deserializes
    /// the response body into `Res`.
    ///
    /// Adds content-type, bearer-token, and Zed-version headers via `build`.
    /// If the server advertises a minimum required version newer than
    /// `app_version`, fails with [`ZedUpdateRequiredError`]. If the server
    /// flags the LLM token as expired, refreshes the token and retries once.
    /// Any other non-success status fails with the status and response body.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server-advertised minimum supported client version.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh once, then loop to retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1659
1660 pub fn refresh_context(
1661 &mut self,
1662 project: &Entity<Project>,
1663 buffer: &Entity<language::Buffer>,
1664 cursor_position: language::Anchor,
1665 cx: &mut Context<Self>,
1666 ) {
1667 if self.use_context {
1668 self.get_or_init_project(project, cx)
1669 .context
1670 .update(cx, |store, cx| {
1671 store.refresh(buffer.clone(), cursor_position, cx);
1672 });
1673 }
1674 }
1675
1676 fn is_file_open_source(
1677 &self,
1678 project: &Entity<Project>,
1679 file: &Arc<dyn File>,
1680 cx: &App,
1681 ) -> bool {
1682 if !file.is_local() || file.is_private() {
1683 return false;
1684 }
1685 let Some(project_state) = self.projects.get(&project.entity_id()) else {
1686 return false;
1687 };
1688 project_state
1689 .license_detection_watchers
1690 .get(&file.worktree_id(cx))
1691 .as_ref()
1692 .is_some_and(|watcher| watcher.is_project_open_source())
1693 }
1694
1695 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
1696 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
1697 }
1698
1699 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
1700 if !self.data_collection_choice.is_enabled() {
1701 return false;
1702 }
1703 events.iter().all(|event| {
1704 matches!(
1705 event.as_ref(),
1706 Event::BufferChange {
1707 in_open_source_repo: true,
1708 ..
1709 }
1710 )
1711 })
1712 }
1713
1714 fn load_data_collection_choice() -> DataCollectionChoice {
1715 let choice = KEY_VALUE_STORE
1716 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1717 .log_err()
1718 .flatten();
1719
1720 match choice.as_deref() {
1721 Some("true") => DataCollectionChoice::Enabled,
1722 Some("false") => DataCollectionChoice::Disabled,
1723 Some(_) => {
1724 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
1725 DataCollectionChoice::NotAnswered
1726 }
1727 None => DataCollectionChoice::NotAnswered,
1728 }
1729 }
1730
1731 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
1732 self.data_collection_choice = self.data_collection_choice.toggle();
1733 let new_choice = self.data_collection_choice;
1734 db::write_and_log(cx, move || {
1735 KEY_VALUE_STORE.write_kvp(
1736 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
1737 new_choice.is_enabled().to_string(),
1738 )
1739 });
1740 }
1741
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
1745
    /// Number of predictions shown to the user so far.
    // NOTE(review): named "completions" while the backing field is
    // `shown_predictions`; name kept unchanged for caller compatibility.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
1749
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
1753
    /// Records a user rating for `prediction` and emits a telemetry event
    /// carrying the rating, the prediction's inputs, a unified diff of its
    /// edits, and the free-form feedback. Telemetry is flushed immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
1772
1773 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
1774 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
1775 && all_language_settings(None, cx).edit_predictions.use_context;
1776 }
1777}
1778
/// Error raised when the LLM service reports a minimum required Zed version
/// newer than the running one; edit predictions are unavailable until the
/// user updates.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // The minimum Zed version advertised by the server.
    minimum_version: Version,
}
1786
/// Cache key used by eval-support caching: the entry kind plus a hash of the
/// request URL and serialized body.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// Category of a cached eval entry; rendered as a lowercase label via its
/// `Display` impl.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
1797
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the stable lowercase label for this cache entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{label}")
    }
}
1808
/// Read/write store for caching eval request/response pairs, keyed by
/// [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores the request (`input`) and its response (`value`) under `key`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1814
/// The user's choice about contributing edit-prediction data.
///
/// Persisted as `"true"`/`"false"` under the data-collection key in the
/// key-value store (see `load_data_collection_choice`).
///
/// Improvement: derives `PartialEq`/`Eq` so choices can be compared and
/// asserted on directly (eager-derive idiom for simple fieldless enums).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// The user has never answered; collection is treated as disabled.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}

impl DataCollectionChoice {
    /// Whether collection is currently active; `NotAnswered` counts as
    /// disabled.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    /// Whether the user has made an explicit decision.
    pub fn is_answered(self) -> bool {
        match self {
            Self::Enabled | Self::Disabled => true,
            Self::NotAnswered => false,
        }
    }

    /// Returns the opposite choice; toggling an unanswered choice enables
    /// collection.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    /// `true` maps to `Enabled`, `false` to `Disabled`.
    fn from(value: bool) -> Self {
        match value {
            true => DataCollectionChoice::Enabled,
            false => DataCollectionChoice::Disabled,
        }
    }
}
1855
1856struct ZedPredictUpsell;
1857
1858impl Dismissable for ZedPredictUpsell {
1859 const KEY: &'static str = "dismissed-edit-predict-upsell";
1860
1861 fn dismissed() -> bool {
1862 // To make this backwards compatible with older versions of Zed, we
1863 // check if the user has seen the previous Edit Prediction Onboarding
1864 // before, by checking the data collection choice which was written to
1865 // the database once the user clicked on "Accept and Enable"
1866 if KEY_VALUE_STORE
1867 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
1868 .log_err()
1869 .is_some_and(|s| s.is_some())
1870 {
1871 return true;
1872 }
1873
1874 KEY_VALUE_STORE
1875 .read_kvp(Self::KEY)
1876 .log_err()
1877 .is_some_and(|s| s.is_some())
1878 }
1879}
1880
/// Whether the edit-prediction upsell modal should be shown, i.e. the user
/// has neither dismissed it nor completed the older onboarding flow.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
1884
/// Registers edit-prediction workspace actions on every new workspace.
///
/// `OpenZedPredictOnboarding` toggles the onboarding modal;
/// `ResetOnboarding` rewrites the settings file, setting the edit prediction
/// provider to `None`.
pub fn init(cx: &mut App) {
    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}