1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::Copilot;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::EditPredictionExcerptOptions;
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, RefreshLlmTokenListener};
34use project::{Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
39use std::collections::{VecDeque, hash_map};
40use text::Edit;
41use workspace::Workspace;
42use zeta_prompt::ZetaVersion;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
50use std::{env, mem};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59mod onboarding_modal;
60pub mod open_ai_response;
61mod prediction;
62pub mod sweep_ai;
63
64pub mod udiff;
65
66mod capture_example;
67mod zed_edit_prediction_delegate;
68pub mod zeta1;
69pub mod zeta2;
70
71#[cfg(test)]
72mod edit_prediction_tests;
73
74use crate::capture_example::should_sample_edit_prediction_example_capture;
75use crate::license_detection::LicenseDetectionWatcher;
76use crate::mercury::Mercury;
77use crate::onboarding_modal::ZedPredictModal;
78pub use crate::prediction::EditPrediction;
79pub use crate::prediction::EditPredictionId;
80use crate::prediction::EditPredictionResult;
81pub use crate::sweep_ai::SweepAi;
82pub use capture_example::capture_example;
83pub use language_model::ApiKeyState;
84pub use telemetry_events::EditPredictionRating;
85pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
86
// Declares the `edit_prediction` actions via gpui's `actions!` macro; the doc
// comment on each entry describes the corresponding action.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
96
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Line span (8) used when grouping changes.
/// NOTE(review): the grouping logic is outside this part of the file — confirm semantics there.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window (1 second) related to grouping the most recent change.
/// NOTE(review): presumably the editing-pause threshold — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Key under which the data collection choice is stored.
/// NOTE(review): presumably a key in the key-value store — confirm at the use site.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval (15 seconds) for reject requests.
/// NOTE(review): the consumer is not visible in this part of the file.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
103
104pub struct SweepFeatureFlag;
105
106impl FeatureFlag for SweepFeatureFlag {
107 const NAME: &str = "sweep-ai";
108}
109
110pub struct MercuryFeatureFlag;
111
112impl FeatureFlag for MercuryFeatureFlag {
113 const NAME: &str = "mercury";
114}
115
/// Default [`ZetaOptions`]: excerpt sizing of 128–512 bytes with a 0.5
/// before-cursor-over-total ratio target, and the default prompt format.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
124
/// `true` when the `ZED_ZETA2_OLLAMA` environment variable is set to a
/// non-empty (and valid Unicode) value. Evaluated once, on first access.
static USE_OLLAMA: LazyLock<bool> = LazyLock::new(|| match env::var("ZED_ZETA2_OLLAMA") {
    Ok(value) => !value.is_empty(),
    Err(_) => false,
});

/// Model id used for edit predictions. Evaluated once, on first access.
///
/// `ZED_ZETA2_MODEL` wins when it can be read: the alias `zeta2-exp` maps to a
/// fixed id and any other value is used verbatim. Otherwise the default
/// depends on whether Ollama mode is enabled.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> =
    LazyLock::new(|| match env::var("ZED_ZETA2_MODEL") {
        // Fine-tuned model @ Baseten
        Ok(requested) if requested == "zeta2-exp" => String::from("4w5n28vw"),
        Ok(requested) => requested,
        Err(_) => {
            let fallback = if *USE_OLLAMA {
                "qwen3-coder:30b"
            } else {
                // Vanilla qwen3-coder @ Baseten
                "yqvev8r3"
            };
            fallback.to_owned()
        }
    });
137
/// Marker type for the feature flag registered under the name `zeta2`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// Always returns `true`, i.e. the flag is reported as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
147
/// Marker type for the feature flag registered under the name
/// `edit-prediction-example-capture`.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    /// Always returns `true`, i.e. the flag is reported as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
157
/// Newtype that lets the shared [`EditPredictionStore`] entity be stored as a
/// gpui global (see `EditPredictionStore::global` / `try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
162
/// Store holding per-project edit prediction state together with the clients,
/// tokens and settings used to request predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of the store.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    use_context: bool,
    options: ZetaOptions,
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    data_collection_choice: DataCollectionChoice,
    /// Sending half of the channel consumed by the background
    /// `handle_rejected_predictions` task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override: parsed from `ZED_PREDICT_EDITS_URL` when that variable
    /// is set, else the local Ollama endpoint when Ollama mode is on, else `None`.
    custom_predict_edits_url: Option<Arc<Url>>,
}
183
/// Which model/provider is used to produce edit predictions.
///
/// The derived `Default` is `Zeta1`; note that `EditPredictionStore::new`
/// starts out with `Zeta2` using the default `ZetaVersion`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
194
/// Bundle of inputs describing a single prediction request: the project and
/// buffer, a snapshot with the position of interest, recent history and context.
/// NOTE(review): the code that builds and consumes this struct is outside this
/// part of the file — confirm field semantics there.
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for emitting `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
208
/// Tunable options for the store; see `DEFAULT_OPTIONS` for the defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Excerpt sizing options.
    pub context: EditPredictionExcerptOptions,
    /// Prompt format sent with requests.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
214
/// Events delivered over the debug channel handed out by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
222
/// Payload emitted when the related-excerpt store reports that a refresh started.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    /// Time at which the debug event was created.
    pub timestamp: Instant,
    /// Search prompt; `handle_excerpt_store_event` sends an empty string here.
    pub search_prompt: String,
}
229
/// Payload emitted when the related-excerpt store reports that a refresh finished.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    /// Time at which the debug event was created.
    pub timestamp: Instant,
    /// Label/value pairs for display, e.g. cache hits and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}
236
/// Payload for `DebugEvent::EditPredictionStarted`.
/// NOTE(review): the emitting code is outside this part of the file.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt text, when available.
    pub prompt: Option<String>,
}
243
/// Payload for `DebugEvent::EditPredictionFinished`.
/// NOTE(review): the emitting code is outside this part of the file.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
250
/// Alias for the debug info type of the `predict_edits_v3` API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// Maximum number of `UserActionRecord`s kept per project; see
/// `ProjectState::record_user_action`.
const USER_ACTION_HISTORY_SIZE: usize = 16;
254
/// A single recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// Row of the point corresponding to `offset`.
    pub line_number: u32,
    /// Buffer offset of the action; for edits recorded by
    /// `report_changes_for_buffer` this is the end of the last edit in the new snapshot.
    pub offset: usize,
    /// Milliseconds since the Unix epoch.
    pub timestamp_epoch_ms: u64,
}
263
/// Coarse classification of a user action stored in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
272
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded event, shared behind an `Arc`.
    pub event: Arc<zeta_prompt::Event>,
    /// Text snapshot of the buffer from before the change the event describes.
    pub old_snapshot: TextBufferSnapshot,
}
279
/// Edit prediction state tracked for a single project.
struct ProjectState {
    /// Finalized events.
    events: VecDeque<StoredEvent>,
    /// Most recent, not yet finalized event; `events()` finalizes it on demand.
    last_event: Option<LastEvent>,
    /// Recently active paths, most recent first (see `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    /// Registered buffers, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug channel installed by `EditPredictionStore::debug_info`, if any.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// One license watcher per worktree; an entry is removed when its worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of user actions (at most `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    /// Subscription to the project's events.
    _subscription: gpui::Subscription,
    /// Copilot instance, created on demand by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
297
298impl ProjectState {
299 fn record_user_action(&mut self, action: UserActionRecord) {
300 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
301 self.user_actions.pop_front();
302 }
303 self.user_actions.push_back(action);
304 }
305
306 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
307 self.events
308 .iter()
309 .cloned()
310 .chain(
311 self.last_event
312 .as_ref()
313 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
314 )
315 .collect()
316 }
317
318 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
319 self.events
320 .iter()
321 .cloned()
322 .chain(self.last_event.as_ref().iter().flat_map(|event| {
323 let (one, two) = event.split_by_pause();
324 let one = one.finalize(&self.license_detection_watchers, cx);
325 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
326 one.into_iter().chain(two)
327 }))
328 .collect()
329 }
330
331 fn cancel_pending_prediction(
332 &mut self,
333 pending_prediction: PendingPrediction,
334 cx: &mut Context<EditPredictionStore>,
335 ) {
336 self.cancelled_predictions.insert(pending_prediction.id);
337
338 cx.spawn(async move |this, cx| {
339 let Some(prediction_id) = pending_prediction.task.await else {
340 return;
341 };
342
343 this.update(cx, |this, _cx| {
344 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
345 })
346 .ok();
347 })
348 .detach()
349 }
350
351 fn active_buffer(
352 &self,
353 project: &Entity<Project>,
354 cx: &App,
355 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
356 let project = project.read(cx);
357 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
358 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
359 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
360 Some((active_buffer, registered_buffer.last_position))
361 }
362}
363
/// The prediction currently held for a project, with bookkeeping about who
/// requested it and whether/how it was shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
371
372impl CurrentEditPrediction {
373 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
374 let Some(new_edits) = self
375 .prediction
376 .interpolate(&self.prediction.buffer.read(cx))
377 else {
378 return false;
379 };
380
381 if self.prediction.buffer != old_prediction.prediction.buffer {
382 return true;
383 }
384
385 let Some(old_edits) = old_prediction
386 .prediction
387 .interpolate(&old_prediction.prediction.buffer.read(cx))
388 else {
389 return true;
390 };
391
392 let requested_by_buffer_id = self.requested_by.buffer_id();
393
394 // This reduces the occurrence of UI thrash from replacing edits
395 //
396 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
397 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
398 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
399 && old_edits.len() == 1
400 && new_edits.len() == 1
401 {
402 let (old_range, old_text) = &old_edits[0];
403 let (new_range, new_text) = &new_edits[0];
404 new_range == old_range && new_text.starts_with(old_text.as_ref())
405 } else {
406 true
407 }
408 }
409}
410
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// Triggered by a diagnostics update; carries no buffer id.
    DiagnosticsUpdate,
    /// Triggered for the buffer with this entity id.
    Buffer(EntityId),
}
416
417impl PredictionRequestedBy {
418 pub fn buffer_id(&self) -> Option<EntityId> {
419 match self {
420 PredictionRequestedBy::DiagnosticsUpdate => None,
421 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
422 }
423 }
424}
425
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancellation.
    id: usize,
    /// Task resolving to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
}
431
/// A prediction from the perspective of a buffer.
///
/// NOTE(review): presumably `Local` means the prediction edits this buffer and
/// `Jump` means it targets a different one — confirm where values are built.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
438
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Test-only convenience: exposes the wrapped prediction regardless of variant.
    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction }
            | BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
450
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// File last seen for the buffer; updated when changes are reported.
    file: Option<Arc<dyn File>>,
    /// Text snapshot last seen for the buffer; updated when changes are reported.
    snapshot: TextBufferSnapshot,
    last_position: Option<Anchor>,
    /// Buffer event subscription and release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}
457
/// The most recent, still open change to a buffer, described by the snapshots
/// and files before and after it.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    /// Snapshot used as the boundary by `split_by_pause`, when present.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
468
469impl LastEvent {
470 pub fn finalize(
471 &self,
472 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
473 cx: &App,
474 ) -> Option<StoredEvent> {
475 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
476 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
477
478 let in_open_source_repo =
479 [self.new_file.as_ref(), self.old_file.as_ref()]
480 .iter()
481 .all(|file| {
482 file.is_some_and(|file| {
483 license_detection_watchers
484 .get(&file.worktree_id(cx))
485 .is_some_and(|watcher| watcher.is_project_open_source())
486 })
487 });
488
489 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
490
491 if path == old_path && diff.is_empty() {
492 None
493 } else {
494 Some(StoredEvent {
495 event: Arc::new(zeta_prompt::Event::BufferChange {
496 old_path,
497 path,
498 diff,
499 in_open_source_repo,
500 // TODO: Actually detect if this edit was predicted or not
501 predicted: false,
502 }),
503 old_snapshot: self.old_snapshot.clone(),
504 })
505 }
506 }
507
508 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
509 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
510 return (self.clone(), None);
511 };
512
513 let before = LastEvent {
514 old_snapshot: self.old_snapshot.clone(),
515 new_snapshot: boundary_snapshot.clone(),
516 old_file: self.old_file.clone(),
517 new_file: self.new_file.clone(),
518 edit_range: None,
519 snapshot_after_last_editing_pause: None,
520 last_edit_time: self.last_edit_time,
521 };
522
523 let after = LastEvent {
524 old_snapshot: boundary_snapshot.clone(),
525 new_snapshot: self.new_snapshot.clone(),
526 old_file: self.old_file.clone(),
527 new_file: self.new_file.clone(),
528 edit_range: None,
529 snapshot_after_last_editing_pause: None,
530 last_edit_time: self.last_edit_time,
531 };
532
533 (before, Some(after))
534 }
535}
536
537pub(crate) fn compute_diff_between_snapshots(
538 old_snapshot: &TextBufferSnapshot,
539 new_snapshot: &TextBufferSnapshot,
540) -> Option<String> {
541 let edits: Vec<Edit<usize>> = new_snapshot
542 .edits_since::<usize>(&old_snapshot.version)
543 .collect();
544
545 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
546
547 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
548 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
549 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
550 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
551
552 const CONTEXT_LINES: u32 = 3;
553
554 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
555 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
556 let old_context_end_row =
557 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
558 let new_context_end_row =
559 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
560
561 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
562 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
563 let old_end_line_offset = old_snapshot
564 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
565 let new_end_line_offset = new_snapshot
566 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
567 let old_edit_range = old_start_line_offset..old_end_line_offset;
568 let new_edit_range = new_start_line_offset..new_end_line_offset;
569
570 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
571 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
572
573 let diff = language::unified_diff_with_offsets(
574 &old_region_text,
575 &new_region_text,
576 old_context_start_row,
577 new_context_start_row,
578 );
579
580 Some(diff)
581}
582
583fn buffer_path_with_id_fallback(
584 file: Option<&Arc<dyn File>>,
585 snapshot: &TextBufferSnapshot,
586 cx: &App,
587) -> Arc<Path> {
588 if let Some(file) = file {
589 file.full_path(cx).into()
590 } else {
591 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
592 }
593}
594
595impl EditPredictionStore {
596 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
597 cx.try_global::<EditPredictionStoreGlobal>()
598 .map(|global| global.0.clone())
599 }
600
601 pub fn global(
602 client: &Arc<Client>,
603 user_store: &Entity<UserStore>,
604 cx: &mut App,
605 ) -> Entity<Self> {
606 cx.try_global::<EditPredictionStoreGlobal>()
607 .map(|global| global.0.clone())
608 .unwrap_or_else(|| {
609 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
610 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
611 ep_store
612 })
613 }
614
    /// Creates a store with default options and the `Zeta2` model selected.
    ///
    /// Besides building the struct this:
    /// - spawns a detached background task that consumes rejected predictions,
    /// - subscribes to the `RefreshLlmTokenListener` to refresh the LLM token,
    /// - resolves the optional custom predict-edits URL from the environment,
    /// - configures context retrieval now, again when feature flags become
    ///   ready, and whenever the `SettingsStore` global changes.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections are sent through this channel and handled on a background task.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        let mut this = Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            use_context: false,
            llm_token,
            // Refresh the token each time the listener emits an event; errors are logged.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            #[cfg(feature = "cli-support")]
            eval_cache: None,
            edit_prediction_model: EditPredictionModel::Zeta2 {
                version: Default::default(),
            },
            sweep_ai: SweepAi::new(cx),
            mercury: Mercury::new(cx),

            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
            // `ZED_PREDICT_EDITS_URL` takes precedence (an unparsable value is logged
            // and yields `None`); otherwise fall back to the local Ollama endpoint
            // when Ollama mode is enabled.
            custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
                Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
                Err(_) => {
                    if *USE_OLLAMA {
                        Some(
                            Url::parse("http://localhost:11434/v1/chat/completions")
                                .unwrap()
                                .into(),
                        )
                    } else {
                        None
                    }
                }
            },
        };

        this.configure_context_retrieval(cx);
        // The entity is not fully constructed yet, so the flags callback goes
        // through a weak handle.
        let weak_this = cx.weak_entity();
        cx.on_flags_ready(move |_, cx| {
            weak_this
                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                .ok();
        })
        .detach();
        cx.observe_global::<SettingsStore>(|this, cx| {
            this.configure_context_retrieval(cx);
        })
        .detach();

        this
    }
703
704 #[cfg(test)]
705 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
706 self.custom_predict_edits_url = Some(url.into());
707 }
708
    /// Selects the model used for subsequent edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }
712
713 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
714 self.sweep_ai.api_token.read(cx).has_key()
715 }
716
717 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
718 self.mercury.api_token.read(cx).has_key()
719 }
720
    /// Installs the evaluation cache (only available with the `cli-support` feature).
    #[cfg(feature = "cli-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
725
    /// Returns the current options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
729
    /// Replaces the current options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
733
    /// Sets the `use_context` flag.
    pub fn set_use_context(&mut self, use_context: bool) {
        self.use_context = use_context;
    }
737
738 pub fn clear_history(&mut self) {
739 for project_state in self.projects.values_mut() {
740 project_state.events.clear();
741 project_state.last_event.take();
742 }
743 }
744
745 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
746 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
747 project_state.events.clear();
748 project_state.last_event.take();
749 }
750 }
751
752 pub fn edit_history_for_project(
753 &self,
754 project: &Entity<Project>,
755 cx: &App,
756 ) -> Vec<StoredEvent> {
757 self.projects
758 .get(&project.entity_id())
759 .map(|project_state| project_state.events(cx))
760 .unwrap_or_default()
761 }
762
763 pub fn edit_history_for_project_with_pause_split_last_event(
764 &self,
765 project: &Entity<Project>,
766 cx: &App,
767 ) -> Vec<StoredEvent> {
768 self.projects
769 .get(&project.entity_id())
770 .map(|project_state| project_state.events_split_by_pause(cx))
771 .unwrap_or_default()
772 }
773
774 pub fn context_for_project<'a>(
775 &'a self,
776 project: &Entity<Project>,
777 cx: &'a mut App,
778 ) -> Vec<RelatedFile> {
779 self.projects
780 .get(&project.entity_id())
781 .map(|project| {
782 project
783 .context
784 .update(cx, |context, cx| context.related_files(cx))
785 })
786 .unwrap_or_default()
787 }
788
789 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
790 self.projects
791 .get(&project.entity_id())
792 .and_then(|project| project.copilot.clone())
793 }
794
795 pub fn start_copilot_for_project(
796 &mut self,
797 project: &Entity<Project>,
798 cx: &mut Context<Self>,
799 ) -> Option<Entity<Copilot>> {
800 let state = self.get_or_init_project(project, cx);
801
802 if state.copilot.is_some() {
803 return state.copilot.clone();
804 }
805 let _project = project.clone();
806 let project = project.read(cx);
807
808 let node = project.node_runtime().cloned();
809 if let Some(node) = node {
810 let next_id = project.languages().next_language_server_id();
811 let fs = project.fs().clone();
812
813 let copilot = cx.new(|cx| Copilot::new(_project, next_id, fs, node, cx));
814 state.copilot = Some(copilot.clone());
815 Some(copilot)
816 } else {
817 None
818 }
819 }
820
821 pub fn context_for_project_with_buffers<'a>(
822 &'a self,
823 project: &Entity<Project>,
824 cx: &'a mut App,
825 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
826 self.projects
827 .get(&project.entity_id())
828 .map(|project| {
829 project
830 .context
831 .update(cx, |context, cx| context.related_files_with_buffers(cx))
832 })
833 .unwrap_or_default()
834 }
835
836 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
837 if matches!(
838 self.edit_prediction_model,
839 EditPredictionModel::Zeta2 { .. }
840 ) {
841 self.user_store.read(cx).edit_prediction_usage()
842 } else {
843 None
844 }
845 }
846
    /// Ensures state exists for `project`, creating it on first use.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }
850
851 pub fn register_buffer(
852 &mut self,
853 buffer: &Entity<Buffer>,
854 project: &Entity<Project>,
855 cx: &mut Context<Self>,
856 ) {
857 let project_state = self.get_or_init_project(project, cx);
858 Self::register_buffer_impl(project_state, buffer, project, cx);
859 }
860
861 fn get_or_init_project(
862 &mut self,
863 project: &Entity<Project>,
864 cx: &mut Context<Self>,
865 ) -> &mut ProjectState {
866 let entity_id = project.entity_id();
867 self.projects
868 .entry(entity_id)
869 .or_insert_with(|| ProjectState {
870 context: {
871 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
872 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
873 this.handle_excerpt_store_event(entity_id, event);
874 })
875 .detach();
876 related_excerpt_store
877 },
878 events: VecDeque::new(),
879 last_event: None,
880 recent_paths: VecDeque::new(),
881 debug_tx: None,
882 registered_buffers: HashMap::default(),
883 current_prediction: None,
884 cancelled_predictions: HashSet::default(),
885 pending_predictions: ArrayVec::new(),
886 next_pending_prediction_id: 0,
887 last_prediction_refresh: None,
888 license_detection_watchers: HashMap::default(),
889 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
890 _subscription: cx.subscribe(&project, Self::handle_project_event),
891 copilot: None,
892 })
893 }
894
    /// Drops all state tracked for `project`; a no-op if it is not tracked.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
898
899 fn handle_excerpt_store_event(
900 &mut self,
901 project_entity_id: EntityId,
902 event: &RelatedExcerptStoreEvent,
903 ) {
904 if let Some(project_state) = self.projects.get(&project_entity_id) {
905 if let Some(debug_tx) = project_state.debug_tx.clone() {
906 match event {
907 RelatedExcerptStoreEvent::StartedRefresh => {
908 debug_tx
909 .unbounded_send(DebugEvent::ContextRetrievalStarted(
910 ContextRetrievalStartedDebugEvent {
911 project_entity_id: project_entity_id,
912 timestamp: Instant::now(),
913 search_prompt: String::new(),
914 },
915 ))
916 .ok();
917 }
918 RelatedExcerptStoreEvent::FinishedRefresh {
919 cache_hit_count,
920 cache_miss_count,
921 mean_definition_latency,
922 max_definition_latency,
923 } => {
924 debug_tx
925 .unbounded_send(DebugEvent::ContextRetrievalFinished(
926 ContextRetrievalFinishedDebugEvent {
927 project_entity_id: project_entity_id,
928 timestamp: Instant::now(),
929 metadata: vec![
930 (
931 "Cache Hits",
932 format!(
933 "{}/{}",
934 cache_hit_count,
935 cache_hit_count + cache_miss_count
936 )
937 .into(),
938 ),
939 (
940 "Max LSP Time",
941 format!("{} ms", max_definition_latency.as_millis())
942 .into(),
943 ),
944 (
945 "Mean LSP Time",
946 format!("{} ms", mean_definition_latency.as_millis())
947 .into(),
948 ),
949 ],
950 },
951 ))
952 .ok();
953 }
954 }
955 }
956 }
957 }
958
959 pub fn debug_info(
960 &mut self,
961 project: &Entity<Project>,
962 cx: &mut Context<Self>,
963 ) -> mpsc::UnboundedReceiver<DebugEvent> {
964 let project_state = self.get_or_init_project(project, cx);
965 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
966 project_state.debug_tx = Some(debug_watch_tx);
967 debug_watch_rx
968 }
969
970 fn handle_project_event(
971 &mut self,
972 project: Entity<Project>,
973 event: &project::Event,
974 cx: &mut Context<Self>,
975 ) {
976 // TODO [zeta2] init with recent paths
977 match event {
978 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
979 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
980 return;
981 };
982 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
983 if let Some(path) = path {
984 if let Some(ix) = project_state
985 .recent_paths
986 .iter()
987 .position(|probe| probe == &path)
988 {
989 project_state.recent_paths.remove(ix);
990 }
991 project_state.recent_paths.push_front(path);
992 }
993 }
994 project::Event::DiagnosticsUpdated { .. } => {
995 if cx.has_flag::<Zeta2FeatureFlag>() {
996 self.refresh_prediction_from_diagnostics(project, cx);
997 }
998 }
999 _ => (),
1000 }
1001 }
1002
    /// Ensures `buffer` is registered in `project_state` and returns its entry.
    ///
    /// If the buffer has a file whose worktree is found in the project, a
    /// license detection watcher is created for that worktree unless one
    /// already exists. A newly registered buffer captures its current text
    /// snapshot and file and installs two subscriptions: one reporting changes
    /// on `Edited` events, one removing the registration on release.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher entry once the worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            // Already registered: keep the existing entry untouched.
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold the project weakly; edits are only reported
                            // while the project can still be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1068
1069 fn report_changes_for_buffer(
1070 &mut self,
1071 buffer: &Entity<Buffer>,
1072 project: &Entity<Project>,
1073 cx: &mut Context<Self>,
1074 ) {
1075 let project_state = self.get_or_init_project(project, cx);
1076 let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
1077
1078 let buf = buffer.read(cx);
1079 let new_file = buf.file().cloned();
1080 let new_snapshot = buf.text_snapshot();
1081 if new_snapshot.version == registered_buffer.snapshot.version {
1082 return;
1083 }
1084
1085 let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
1086 let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
1087 let mut num_edits = 0usize;
1088 let mut total_deleted = 0usize;
1089 let mut total_inserted = 0usize;
1090 let mut edit_range: Option<Range<Anchor>> = None;
1091 let mut last_offset: Option<usize> = None;
1092
1093 for (edit, anchor_range) in
1094 new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
1095 {
1096 num_edits += 1;
1097 total_deleted += edit.old.len();
1098 total_inserted += edit.new.len();
1099 edit_range = Some(match edit_range {
1100 None => anchor_range,
1101 Some(acc) => acc.start..anchor_range.end,
1102 });
1103 last_offset = Some(edit.new.end);
1104 }
1105
1106 if num_edits > 0 {
1107 let action_type = match (total_deleted, total_inserted, num_edits) {
1108 (0, ins, n) if ins == n => UserActionType::InsertChar,
1109 (0, _, _) => UserActionType::InsertSelection,
1110 (del, 0, n) if del == n => UserActionType::DeleteChar,
1111 (_, 0, _) => UserActionType::DeleteSelection,
1112 (_, ins, n) if ins == n => UserActionType::InsertChar,
1113 (_, _, _) => UserActionType::InsertSelection,
1114 };
1115
1116 if let Some(offset) = last_offset {
1117 let point = new_snapshot.offset_to_point(offset);
1118 let timestamp_epoch_ms = SystemTime::now()
1119 .duration_since(UNIX_EPOCH)
1120 .map(|d| d.as_millis() as u64)
1121 .unwrap_or(0);
1122 project_state.record_user_action(UserActionRecord {
1123 action_type,
1124 buffer_id: buffer.entity_id(),
1125 line_number: point.row,
1126 offset,
1127 timestamp_epoch_ms,
1128 });
1129 }
1130 }
1131
1132 let events = &mut project_state.events;
1133
1134 let now = cx.background_executor().now();
1135 if let Some(last_event) = project_state.last_event.as_mut() {
1136 let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
1137 == last_event.new_snapshot.remote_id()
1138 && old_snapshot.version == last_event.new_snapshot.version;
1139
1140 let should_coalesce = is_next_snapshot_of_same_buffer
1141 && edit_range
1142 .as_ref()
1143 .zip(last_event.edit_range.as_ref())
1144 .is_some_and(|(a, b)| {
1145 let a = a.to_point(&new_snapshot);
1146 let b = b.to_point(&new_snapshot);
1147 if a.start > b.end {
1148 a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
1149 } else if b.start > a.end {
1150 b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
1151 } else {
1152 true
1153 }
1154 });
1155
1156 if should_coalesce {
1157 let pause_elapsed = last_event
1158 .last_edit_time
1159 .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
1160 .unwrap_or(false);
1161 if pause_elapsed {
1162 last_event.snapshot_after_last_editing_pause =
1163 Some(last_event.new_snapshot.clone());
1164 }
1165
1166 last_event.edit_range = edit_range;
1167 last_event.new_snapshot = new_snapshot;
1168 last_event.last_edit_time = Some(now);
1169 return;
1170 }
1171 }
1172
1173 if let Some(event) = project_state.last_event.take() {
1174 if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
1175 if events.len() + 1 >= EVENT_COUNT_MAX {
1176 events.pop_front();
1177 }
1178 events.push_back(event);
1179 }
1180 }
1181
1182 project_state.last_event = Some(LastEvent {
1183 old_file,
1184 new_file,
1185 old_snapshot,
1186 new_snapshot,
1187 edit_range,
1188 snapshot_after_last_editing_pause: None,
1189 last_edit_time: Some(now),
1190 });
1191 }
1192
1193 fn prediction_at(
1194 &mut self,
1195 buffer: &Entity<Buffer>,
1196 position: Option<language::Anchor>,
1197 project: &Entity<Project>,
1198 cx: &App,
1199 ) -> Option<BufferEditPrediction<'_>> {
1200 let project_state = self.projects.get_mut(&project.entity_id())?;
1201 if let Some(position) = position
1202 && let Some(buffer) = project_state
1203 .registered_buffers
1204 .get_mut(&buffer.entity_id())
1205 {
1206 buffer.last_position = Some(position);
1207 }
1208
1209 let CurrentEditPrediction {
1210 requested_by,
1211 prediction,
1212 ..
1213 } = project_state.current_prediction.as_ref()?;
1214
1215 if prediction.targets_buffer(buffer.read(cx)) {
1216 Some(BufferEditPrediction::Local { prediction })
1217 } else {
1218 let show_jump = match requested_by {
1219 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1220 requested_by_buffer_id == &buffer.entity_id()
1221 }
1222 PredictionRequestedBy::DiagnosticsUpdate => true,
1223 };
1224
1225 if show_jump {
1226 Some(BufferEditPrediction::Jump { prediction })
1227 } else {
1228 None
1229 }
1230 }
1231 }
1232
1233 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1234 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1235 return;
1236 };
1237
1238 let Some(current_prediction) = project_state.current_prediction.take() else {
1239 return;
1240 };
1241
1242 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1243 project_state.cancel_pending_prediction(pending_prediction, cx);
1244 }
1245
1246 match self.edit_prediction_model {
1247 EditPredictionModel::Sweep => {
1248 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1249 }
1250 EditPredictionModel::Mercury => {}
1251 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1252 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1253 }
1254 }
1255 }
1256
1257 async fn handle_rejected_predictions(
1258 rx: UnboundedReceiver<EditPredictionRejection>,
1259 client: Arc<Client>,
1260 llm_token: LlmApiToken,
1261 app_version: Version,
1262 background_executor: BackgroundExecutor,
1263 ) {
1264 let mut rx = std::pin::pin!(rx.peekable());
1265 let mut batched = Vec::new();
1266
1267 while let Some(rejection) = rx.next().await {
1268 batched.push(rejection);
1269
1270 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1271 select_biased! {
1272 next = rx.as_mut().peek().fuse() => {
1273 if next.is_some() {
1274 continue;
1275 }
1276 }
1277 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1278 }
1279 }
1280
1281 let url = client
1282 .http_client()
1283 .build_zed_llm_url("/predict_edits/reject", &[])
1284 .unwrap();
1285
1286 let flush_count = batched
1287 .len()
1288 // in case items have accumulated after failure
1289 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1290 let start = batched.len() - flush_count;
1291
1292 let body = RejectEditPredictionsBodyRef {
1293 rejections: &batched[start..],
1294 };
1295
1296 let result = Self::send_api_request::<()>(
1297 |builder| {
1298 let req = builder
1299 .uri(url.as_ref())
1300 .body(serde_json::to_string(&body)?.into());
1301 anyhow::Ok(req?)
1302 },
1303 client.clone(),
1304 llm_token.clone(),
1305 app_version.clone(),
1306 true,
1307 )
1308 .await;
1309
1310 if result.log_err().is_some() {
1311 batched.drain(start..);
1312 }
1313 }
1314 }
1315
1316 fn reject_current_prediction(
1317 &mut self,
1318 reason: EditPredictionRejectReason,
1319 project: &Entity<Project>,
1320 ) {
1321 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1322 project_state.pending_predictions.clear();
1323 if let Some(prediction) = project_state.current_prediction.take() {
1324 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1325 }
1326 };
1327 }
1328
1329 fn did_show_current_prediction(
1330 &mut self,
1331 project: &Entity<Project>,
1332 display_type: edit_prediction_types::SuggestionDisplayType,
1333 cx: &mut Context<Self>,
1334 ) {
1335 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1336 return;
1337 };
1338
1339 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1340 return;
1341 };
1342
1343 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1344 let previous_shown_with = current_prediction.shown_with;
1345
1346 if previous_shown_with.is_none() || !is_jump {
1347 current_prediction.shown_with = Some(display_type);
1348 }
1349
1350 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1351
1352 if is_first_non_jump_show {
1353 current_prediction.was_shown = true;
1354 }
1355
1356 let display_type_changed = previous_shown_with != Some(display_type);
1357
1358 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1359 sweep_ai::edit_prediction_shown(
1360 &self.sweep_ai,
1361 self.client.clone(),
1362 ¤t_prediction.prediction,
1363 display_type,
1364 cx,
1365 );
1366 }
1367
1368 if is_first_non_jump_show {
1369 self.shown_predictions
1370 .push_front(current_prediction.prediction.clone());
1371 if self.shown_predictions.len() > 50 {
1372 let completion = self.shown_predictions.pop_back().unwrap();
1373 self.rated_predictions.remove(&completion.id);
1374 }
1375 }
1376 }
1377
1378 fn reject_prediction(
1379 &mut self,
1380 prediction_id: EditPredictionId,
1381 reason: EditPredictionRejectReason,
1382 was_shown: bool,
1383 ) {
1384 match self.edit_prediction_model {
1385 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1386 if self.custom_predict_edits_url.is_some() {
1387 return;
1388 }
1389 }
1390 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1391 }
1392
1393 self.reject_predictions_tx
1394 .unbounded_send(EditPredictionRejection {
1395 request_id: prediction_id.to_string(),
1396 reason,
1397 was_shown,
1398 })
1399 .log_err();
1400 }
1401
1402 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1403 self.projects
1404 .get(&project.entity_id())
1405 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1406 }
1407
1408 pub fn refresh_prediction_from_buffer(
1409 &mut self,
1410 project: Entity<Project>,
1411 buffer: Entity<Buffer>,
1412 position: language::Anchor,
1413 cx: &mut Context<Self>,
1414 ) {
1415 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1416 let Some(request_task) = this
1417 .update(cx, |this, cx| {
1418 this.request_prediction(
1419 &project,
1420 &buffer,
1421 position,
1422 PredictEditsRequestTrigger::Other,
1423 cx,
1424 )
1425 })
1426 .log_err()
1427 else {
1428 return Task::ready(anyhow::Ok(None));
1429 };
1430
1431 cx.spawn(async move |_cx| {
1432 request_task.await.map(|prediction_result| {
1433 prediction_result.map(|prediction_result| {
1434 (
1435 prediction_result,
1436 PredictionRequestedBy::Buffer(buffer.entity_id()),
1437 )
1438 })
1439 })
1440 })
1441 })
1442 }
1443
1444 pub fn refresh_prediction_from_diagnostics(
1445 &mut self,
1446 project: Entity<Project>,
1447 cx: &mut Context<Self>,
1448 ) {
1449 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1450 return;
1451 };
1452
1453 // Prefer predictions from buffer
1454 if project_state.current_prediction.is_some() {
1455 return;
1456 };
1457
1458 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1459 let Some((active_buffer, snapshot, cursor_point)) = this
1460 .read_with(cx, |this, cx| {
1461 let project_state = this.projects.get(&project.entity_id())?;
1462 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1463 let snapshot = buffer.read(cx).snapshot();
1464
1465 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1466 return None;
1467 }
1468
1469 let cursor_point = position
1470 .map(|pos| pos.to_point(&snapshot))
1471 .unwrap_or_default();
1472
1473 Some((buffer, snapshot, cursor_point))
1474 })
1475 .log_err()
1476 .flatten()
1477 else {
1478 return Task::ready(anyhow::Ok(None));
1479 };
1480
1481 cx.spawn(async move |cx| {
1482 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1483 active_buffer,
1484 &snapshot,
1485 Default::default(),
1486 cursor_point,
1487 &project,
1488 cx,
1489 )
1490 .await?
1491 else {
1492 return anyhow::Ok(None);
1493 };
1494
1495 let Some(prediction_result) = this
1496 .update(cx, |this, cx| {
1497 this.request_prediction(
1498 &project,
1499 &jump_buffer,
1500 jump_position,
1501 PredictEditsRequestTrigger::Diagnostics,
1502 cx,
1503 )
1504 })?
1505 .await?
1506 else {
1507 return anyhow::Ok(None);
1508 };
1509
1510 this.update(cx, |this, cx| {
1511 Some((
1512 if this
1513 .get_or_init_project(&project, cx)
1514 .current_prediction
1515 .is_none()
1516 {
1517 prediction_result
1518 } else {
1519 EditPredictionResult {
1520 id: prediction_result.id,
1521 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1522 }
1523 },
1524 PredictionRequestedBy::DiagnosticsUpdate,
1525 ))
1526 })
1527 })
1528 });
1529 }
1530
1531 fn predictions_enabled_at(
1532 snapshot: &BufferSnapshot,
1533 position: Option<language::Anchor>,
1534 cx: &App,
1535 ) -> bool {
1536 let file = snapshot.file();
1537 let all_settings = all_language_settings(file, cx);
1538 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1539 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1540 {
1541 return false;
1542 }
1543
1544 if let Some(last_position) = position {
1545 let settings = snapshot.settings_at(last_position, cx);
1546
1547 if !settings.edit_predictions_disabled_in.is_empty()
1548 && let Some(scope) = snapshot.language_scope_at(last_position)
1549 && let Some(scope_name) = scope.override_name()
1550 && settings
1551 .edit_predictions_disabled_in
1552 .iter()
1553 .any(|s| s == scope_name)
1554 {
1555 return false;
1556 }
1557 }
1558
1559 true
1560 }
1561
/// Minimum interval between two prediction refreshes for the same throttle entity;
/// `queue_prediction_refresh` waits out the remainder before issuing a request.
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
/// Test builds use a zero throttle interval, so refreshes are never delayed.
#[cfg(test)]
pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1566
1567 fn queue_prediction_refresh(
1568 &mut self,
1569 project: Entity<Project>,
1570 throttle_entity: EntityId,
1571 cx: &mut Context<Self>,
1572 do_refresh: impl FnOnce(
1573 WeakEntity<Self>,
1574 &mut AsyncApp,
1575 )
1576 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1577 + 'static,
1578 ) {
1579 let project_state = self.get_or_init_project(&project, cx);
1580 let pending_prediction_id = project_state.next_pending_prediction_id;
1581 project_state.next_pending_prediction_id += 1;
1582 let last_request = project_state.last_prediction_refresh;
1583
1584 let task = cx.spawn(async move |this, cx| {
1585 if let Some((last_entity, last_timestamp)) = last_request
1586 && throttle_entity == last_entity
1587 && let Some(timeout) =
1588 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1589 {
1590 cx.background_executor().timer(timeout).await;
1591 }
1592
1593 // If this task was cancelled before the throttle timeout expired,
1594 // do not perform a request.
1595 let mut is_cancelled = true;
1596 this.update(cx, |this, cx| {
1597 let project_state = this.get_or_init_project(&project, cx);
1598 if !project_state
1599 .cancelled_predictions
1600 .remove(&pending_prediction_id)
1601 {
1602 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1603 is_cancelled = false;
1604 }
1605 })
1606 .ok();
1607 if is_cancelled {
1608 return None;
1609 }
1610
1611 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1612 let new_prediction_id = new_prediction_result
1613 .as_ref()
1614 .map(|(prediction, _)| prediction.id.clone());
1615
1616 // When a prediction completes, remove it from the pending list, and cancel
1617 // any pending predictions that were enqueued before it.
1618 this.update(cx, |this, cx| {
1619 let project_state = this.get_or_init_project(&project, cx);
1620
1621 let is_cancelled = project_state
1622 .cancelled_predictions
1623 .remove(&pending_prediction_id);
1624
1625 let new_current_prediction = if !is_cancelled
1626 && let Some((prediction_result, requested_by)) = new_prediction_result
1627 {
1628 match prediction_result.prediction {
1629 Ok(prediction) => {
1630 let new_prediction = CurrentEditPrediction {
1631 requested_by,
1632 prediction,
1633 was_shown: false,
1634 shown_with: None,
1635 };
1636
1637 if let Some(current_prediction) =
1638 project_state.current_prediction.as_ref()
1639 {
1640 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1641 {
1642 this.reject_current_prediction(
1643 EditPredictionRejectReason::Replaced,
1644 &project,
1645 );
1646
1647 Some(new_prediction)
1648 } else {
1649 this.reject_prediction(
1650 new_prediction.prediction.id,
1651 EditPredictionRejectReason::CurrentPreferred,
1652 false,
1653 );
1654 None
1655 }
1656 } else {
1657 Some(new_prediction)
1658 }
1659 }
1660 Err(reject_reason) => {
1661 this.reject_prediction(prediction_result.id, reject_reason, false);
1662 None
1663 }
1664 }
1665 } else {
1666 None
1667 };
1668
1669 let project_state = this.get_or_init_project(&project, cx);
1670
1671 if let Some(new_prediction) = new_current_prediction {
1672 project_state.current_prediction = Some(new_prediction);
1673 }
1674
1675 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1676 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1677 if pending_prediction.id == pending_prediction_id {
1678 pending_predictions.remove(ix);
1679 for pending_prediction in pending_predictions.drain(0..ix) {
1680 project_state.cancel_pending_prediction(pending_prediction, cx)
1681 }
1682 break;
1683 }
1684 }
1685 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1686 cx.notify();
1687 })
1688 .ok();
1689
1690 new_prediction_id
1691 });
1692
1693 if project_state.pending_predictions.len() <= 1 {
1694 project_state.pending_predictions.push(PendingPrediction {
1695 id: pending_prediction_id,
1696 task,
1697 });
1698 } else if project_state.pending_predictions.len() == 2 {
1699 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1700 project_state.pending_predictions.push(PendingPrediction {
1701 id: pending_prediction_id,
1702 task,
1703 });
1704 project_state.cancel_pending_prediction(pending_prediction, cx);
1705 }
1706 }
1707
1708 pub fn request_prediction(
1709 &mut self,
1710 project: &Entity<Project>,
1711 active_buffer: &Entity<Buffer>,
1712 position: language::Anchor,
1713 trigger: PredictEditsRequestTrigger,
1714 cx: &mut Context<Self>,
1715 ) -> Task<Result<Option<EditPredictionResult>>> {
1716 self.request_prediction_internal(
1717 project.clone(),
1718 active_buffer.clone(),
1719 position,
1720 trigger,
1721 cx.has_flag::<Zeta2FeatureFlag>(),
1722 cx,
1723 )
1724 }
1725
1726 fn request_prediction_internal(
1727 &mut self,
1728 project: Entity<Project>,
1729 active_buffer: Entity<Buffer>,
1730 position: language::Anchor,
1731 trigger: PredictEditsRequestTrigger,
1732 allow_jump: bool,
1733 cx: &mut Context<Self>,
1734 ) -> Task<Result<Option<EditPredictionResult>>> {
1735 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1736
1737 self.get_or_init_project(&project, cx);
1738 let project_state = self.projects.get(&project.entity_id()).unwrap();
1739 let stored_events = project_state.events(cx);
1740 let has_events = !stored_events.is_empty();
1741 let events: Vec<Arc<zeta_prompt::Event>> =
1742 stored_events.into_iter().map(|e| e.event).collect();
1743 let debug_tx = project_state.debug_tx.clone();
1744
1745 let snapshot = active_buffer.read(cx).snapshot();
1746 let cursor_point = position.to_point(&snapshot);
1747 let current_offset = position.to_offset(&snapshot);
1748
1749 let mut user_actions: Vec<UserActionRecord> =
1750 project_state.user_actions.iter().cloned().collect();
1751
1752 if let Some(last_action) = user_actions.last() {
1753 if last_action.buffer_id == active_buffer.entity_id()
1754 && current_offset != last_action.offset
1755 {
1756 let timestamp_epoch_ms = SystemTime::now()
1757 .duration_since(UNIX_EPOCH)
1758 .map(|d| d.as_millis() as u64)
1759 .unwrap_or(0);
1760 user_actions.push(UserActionRecord {
1761 action_type: UserActionType::CursorMovement,
1762 buffer_id: active_buffer.entity_id(),
1763 line_number: cursor_point.row,
1764 offset: current_offset,
1765 timestamp_epoch_ms,
1766 });
1767 }
1768 }
1769 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1770 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1771 let diagnostic_search_range =
1772 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1773
1774 let related_files = if self.use_context {
1775 self.context_for_project(&project, cx)
1776 } else {
1777 Vec::new()
1778 };
1779
1780 let inputs = EditPredictionModelInput {
1781 project: project.clone(),
1782 buffer: active_buffer.clone(),
1783 snapshot: snapshot.clone(),
1784 position,
1785 events,
1786 related_files,
1787 recent_paths: project_state.recent_paths.clone(),
1788 trigger,
1789 diagnostic_search_range: diagnostic_search_range.clone(),
1790 debug_tx,
1791 user_actions,
1792 };
1793
1794 let can_collect_example = snapshot
1795 .file()
1796 .is_some_and(|file| self.can_collect_file(&project, file, cx))
1797 && self.can_collect_events(&inputs.events, cx);
1798
1799 if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
1800 let events_for_capture =
1801 self.edit_history_for_project_with_pause_split_last_event(&project, cx);
1802 if let Some(example_task) = capture_example::capture_example(
1803 project.clone(),
1804 active_buffer.clone(),
1805 position,
1806 events_for_capture,
1807 false,
1808 cx,
1809 ) {
1810 cx.spawn(async move |_this, _cx| {
1811 let example = example_task.await?;
1812 telemetry::event!("Edit Prediction Example Captured", example = example);
1813 anyhow::Ok(())
1814 })
1815 .detach_and_log_err(cx);
1816 }
1817 }
1818 let task = match self.edit_prediction_model {
1819 EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
1820 EditPredictionModel::Zeta2 { version } => {
1821 zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
1822 }
1823 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1824 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1825 };
1826
1827 cx.spawn(async move |this, cx| {
1828 let prediction = task.await?;
1829
1830 if prediction.is_none() && allow_jump {
1831 let cursor_point = position.to_point(&snapshot);
1832 if has_events
1833 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1834 active_buffer.clone(),
1835 &snapshot,
1836 diagnostic_search_range,
1837 cursor_point,
1838 &project,
1839 cx,
1840 )
1841 .await?
1842 {
1843 return this
1844 .update(cx, |this, cx| {
1845 this.request_prediction_internal(
1846 project,
1847 jump_buffer,
1848 jump_position,
1849 trigger,
1850 false,
1851 cx,
1852 )
1853 })?
1854 .await;
1855 }
1856
1857 return anyhow::Ok(None);
1858 }
1859
1860 Ok(prediction)
1861 })
1862 }
1863
1864 async fn next_diagnostic_location(
1865 active_buffer: Entity<Buffer>,
1866 active_buffer_snapshot: &BufferSnapshot,
1867 active_buffer_diagnostic_search_range: Range<Point>,
1868 active_buffer_cursor_point: Point,
1869 project: &Entity<Project>,
1870 cx: &mut AsyncApp,
1871 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1872 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1873 let mut jump_location = active_buffer_snapshot
1874 .diagnostic_groups(None)
1875 .into_iter()
1876 .filter_map(|(_, group)| {
1877 let range = &group.entries[group.primary_ix]
1878 .range
1879 .to_point(&active_buffer_snapshot);
1880 if range.overlaps(&active_buffer_diagnostic_search_range) {
1881 None
1882 } else {
1883 Some(range.start)
1884 }
1885 })
1886 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1887 .map(|position| {
1888 (
1889 active_buffer.clone(),
1890 active_buffer_snapshot.anchor_before(position),
1891 )
1892 });
1893
1894 if jump_location.is_none() {
1895 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1896 let file = buffer.file()?;
1897
1898 Some(ProjectPath {
1899 worktree_id: file.worktree_id(cx),
1900 path: file.path().clone(),
1901 })
1902 });
1903
1904 let buffer_task = project.update(cx, |project, cx| {
1905 let (path, _, _) = project
1906 .diagnostic_summaries(false, cx)
1907 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1908 .max_by_key(|(path, _, _)| {
1909 // find the buffer with errors that shares most parent directories
1910 path.path
1911 .components()
1912 .zip(
1913 active_buffer_path
1914 .as_ref()
1915 .map(|p| p.path.components())
1916 .unwrap_or_default(),
1917 )
1918 .take_while(|(a, b)| a == b)
1919 .count()
1920 })?;
1921
1922 Some(project.open_buffer(path, cx))
1923 });
1924
1925 if let Some(buffer_task) = buffer_task {
1926 let closest_buffer = buffer_task.await?;
1927
1928 jump_location = closest_buffer
1929 .read_with(cx, |buffer, _cx| {
1930 buffer
1931 .buffer_diagnostics(None)
1932 .into_iter()
1933 .min_by_key(|entry| entry.diagnostic.severity)
1934 .map(|entry| entry.range.start)
1935 })
1936 .map(|position| (closest_buffer, position));
1937 }
1938 }
1939
1940 anyhow::Ok(jump_location)
1941 }
1942
1943 async fn send_raw_llm_request(
1944 request: RawCompletionRequest,
1945 client: Arc<Client>,
1946 custom_url: Option<Arc<Url>>,
1947 llm_token: LlmApiToken,
1948 app_version: Version,
1949 #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1950 #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
1951 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1952 let url = if let Some(custom_url) = custom_url {
1953 custom_url.as_ref().clone()
1954 } else {
1955 client
1956 .http_client()
1957 .build_zed_llm_url("/predict_edits/raw", &[])?
1958 };
1959
1960 #[cfg(feature = "cli-support")]
1961 let cache_key = if let Some(cache) = eval_cache {
1962 use collections::FxHasher;
1963 use std::hash::{Hash, Hasher};
1964
1965 let mut hasher = FxHasher::default();
1966 url.hash(&mut hasher);
1967 let request_str = serde_json::to_string_pretty(&request)?;
1968 request_str.hash(&mut hasher);
1969 let hash = hasher.finish();
1970
1971 let key = (eval_cache_kind, hash);
1972 if let Some(response_str) = cache.read(key) {
1973 return Ok((serde_json::from_str(&response_str)?, None));
1974 }
1975
1976 Some((cache, request_str, key))
1977 } else {
1978 None
1979 };
1980
1981 let (response, usage) = Self::send_api_request(
1982 |builder| {
1983 let req = builder
1984 .uri(url.as_ref())
1985 .body(serde_json::to_string(&request)?.into());
1986 Ok(req?)
1987 },
1988 client,
1989 llm_token,
1990 app_version,
1991 true,
1992 )
1993 .await?;
1994
1995 #[cfg(feature = "cli-support")]
1996 if let Some((cache, request, key)) = cache_key {
1997 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1998 }
1999
2000 Ok((response, usage))
2001 }
2002
2003 fn handle_api_response<T>(
2004 this: &WeakEntity<Self>,
2005 response: Result<(T, Option<EditPredictionUsage>)>,
2006 cx: &mut gpui::AsyncApp,
2007 ) -> Result<T> {
2008 match response {
2009 Ok((data, usage)) => {
2010 if let Some(usage) = usage {
2011 this.update(cx, |this, cx| {
2012 this.user_store.update(cx, |user_store, cx| {
2013 user_store.update_edit_prediction_usage(usage, cx);
2014 });
2015 })
2016 .ok();
2017 }
2018 Ok(data)
2019 }
2020 Err(err) => {
2021 if err.is::<ZedUpdateRequiredError>() {
2022 cx.update(|cx| {
2023 this.update(cx, |this, _cx| {
2024 this.update_required = true;
2025 })
2026 .ok();
2027
2028 let error_message: SharedString = err.to_string().into();
2029 show_app_notification(
2030 NotificationId::unique::<ZedUpdateRequiredError>(),
2031 cx,
2032 move |cx| {
2033 cx.new(|cx| {
2034 ErrorMessagePrompt::new(error_message.clone(), cx)
2035 .with_link_button("Update Zed", "https://zed.dev/releases")
2036 })
2037 },
2038 );
2039 });
2040 }
2041 Err(err)
2042 }
2043 }
2044 }
2045
2046 async fn send_api_request<Res>(
2047 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2048 client: Arc<Client>,
2049 llm_token: LlmApiToken,
2050 app_version: Version,
2051 require_auth: bool,
2052 ) -> Result<(Res, Option<EditPredictionUsage>)>
2053 where
2054 Res: DeserializeOwned,
2055 {
2056 let http_client = client.http_client();
2057
2058 let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2059 Some(custom_token)
2060 } else if require_auth {
2061 Some(llm_token.acquire(&client).await?)
2062 } else {
2063 llm_token.acquire(&client).await.ok()
2064 };
2065 let mut did_retry = false;
2066
2067 loop {
2068 let request_builder = http_client::Request::builder().method(Method::POST);
2069
2070 let mut request_builder = request_builder
2071 .header("Content-Type", "application/json")
2072 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2073
2074 // Only add Authorization header if we have a token
2075 if let Some(ref token_value) = token {
2076 request_builder =
2077 request_builder.header("Authorization", format!("Bearer {}", token_value));
2078 }
2079
2080 let request = build(request_builder)?;
2081
2082 let mut response = http_client.send(request).await?;
2083
2084 if let Some(minimum_required_version) = response
2085 .headers()
2086 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2087 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2088 {
2089 anyhow::ensure!(
2090 app_version >= minimum_required_version,
2091 ZedUpdateRequiredError {
2092 minimum_version: minimum_required_version
2093 }
2094 );
2095 }
2096
2097 if response.status().is_success() {
2098 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2099
2100 let mut body = Vec::new();
2101 response.body_mut().read_to_end(&mut body).await?;
2102 return Ok((serde_json::from_slice(&body)?, usage));
2103 } else if !did_retry
2104 && token.is_some()
2105 && response
2106 .headers()
2107 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
2108 .is_some()
2109 {
2110 did_retry = true;
2111 token = Some(llm_token.refresh(&client).await?);
2112 } else {
2113 let mut body = String::new();
2114 response.body_mut().read_to_string(&mut body).await?;
2115 anyhow::bail!(
2116 "Request failed with status: {:?}\nBody: {}",
2117 response.status(),
2118 body
2119 );
2120 }
2121 }
2122 }
2123
2124 pub fn refresh_context(
2125 &mut self,
2126 project: &Entity<Project>,
2127 buffer: &Entity<language::Buffer>,
2128 cursor_position: language::Anchor,
2129 cx: &mut Context<Self>,
2130 ) {
2131 if self.use_context {
2132 self.get_or_init_project(project, cx)
2133 .context
2134 .update(cx, |store, cx| {
2135 store.refresh(buffer.clone(), cursor_position, cx);
2136 });
2137 }
2138 }
2139
/// Replaces the related files held by the project's context store with `related_files`.
#[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
    &mut self,
    project: &Entity<Project>,
    related_files: Vec<RelatedFile>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_project(project, cx);
    project_state.context.update(cx, |store, cx| {
        store.set_related_files(related_files, cx);
    });
}
2153
2154 fn is_file_open_source(
2155 &self,
2156 project: &Entity<Project>,
2157 file: &Arc<dyn File>,
2158 cx: &App,
2159 ) -> bool {
2160 if !file.is_local() || file.is_private() {
2161 return false;
2162 }
2163 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2164 return false;
2165 };
2166 project_state
2167 .license_detection_watchers
2168 .get(&file.worktree_id(cx))
2169 .as_ref()
2170 .is_some_and(|watcher| watcher.is_project_open_source())
2171 }
2172
2173 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2174 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2175 }
2176
2177 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2178 if !self.data_collection_choice.is_enabled(cx) {
2179 return false;
2180 }
2181 events.iter().all(|event| {
2182 matches!(
2183 event.as_ref(),
2184 zeta_prompt::Event::BufferChange {
2185 in_open_source_repo: true,
2186 ..
2187 }
2188 )
2189 })
2190 }
2191
2192 fn load_data_collection_choice() -> DataCollectionChoice {
2193 let choice = KEY_VALUE_STORE
2194 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2195 .log_err()
2196 .flatten();
2197
2198 match choice.as_deref() {
2199 Some("true") => DataCollectionChoice::Enabled,
2200 Some("false") => DataCollectionChoice::Disabled,
2201 Some(_) => {
2202 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2203 DataCollectionChoice::NotAnswered
2204 }
2205 None => DataCollectionChoice::NotAnswered,
2206 }
2207 }
2208
2209 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2210 self.data_collection_choice = self.data_collection_choice.toggle();
2211 let new_choice = self.data_collection_choice;
2212 let is_enabled = new_choice.is_enabled(cx);
2213 db::write_and_log(cx, move || {
2214 KEY_VALUE_STORE.write_kvp(
2215 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2216 is_enabled.to_string(),
2217 )
2218 });
2219 }
2220
2221 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2222 self.shown_predictions.iter()
2223 }
2224
2225 pub fn shown_completions_len(&self) -> usize {
2226 self.shown_predictions.len()
2227 }
2228
2229 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2230 self.rated_predictions.contains(id)
2231 }
2232
2233 pub fn rate_prediction(
2234 &mut self,
2235 prediction: &EditPrediction,
2236 rating: EditPredictionRating,
2237 feedback: String,
2238 cx: &mut Context<Self>,
2239 ) {
2240 self.rated_predictions.insert(prediction.id.clone());
2241 telemetry::event!(
2242 "Edit Prediction Rated",
2243 rating,
2244 inputs = prediction.inputs,
2245 output = prediction
2246 .edit_preview
2247 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2248 feedback
2249 );
2250 self.client.telemetry().flush_events().detach();
2251 cx.notify();
2252 }
2253
2254 fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
2255 self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
2256 && all_language_settings(None, cx).edit_predictions.use_context;
2257 }
2258}
2259
2260pub(crate) fn filter_redundant_excerpts(
2261 mut related_files: Vec<RelatedFile>,
2262 cursor_path: &Path,
2263 cursor_row_range: Range<u32>,
2264) -> Vec<RelatedFile> {
2265 for file in &mut related_files {
2266 if file.path.as_ref() == cursor_path {
2267 file.excerpts.retain(|excerpt| {
2268 excerpt.row_range.start < cursor_row_range.start
2269 || excerpt.row_range.end > cursor_row_range.end
2270 });
2271 }
2272 }
2273 related_files.retain(|file| !file.excerpts.is_empty());
2274 related_files
2275}
2276
/// Error telling the user that edit predictions require a newer Zed version.
///
/// Its display message names `minimum_version` as the version to update to.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version named in the error message.
    minimum_version: Version,
}
2284
/// Key used by `EvalCache`: the entry kind paired with a `u64`.
// NOTE(review): the `u64` is presumably a hash of the request input; confirm
// against the callers that build keys.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2287
/// Category of an entry stored in an `EvalCache`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so that
/// `EvalCacheKey`, which embeds this type, can be used as a key in hashed
/// collections.
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    /// Displayed as `context`.
    Context,
    /// Displayed as `search`.
    Search,
    /// Displayed as `prediction`.
    Prediction,
}
2295
/// Formats the kind as its lowercase name: `context`, `search` or
/// `prediction`.
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2306
/// Cache interface keyed by `EvalCacheKey`, available with the
/// `cli-support` feature. Implementations must be `Send + Sync`.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Looks up the value stored for `key`, or `None` when there is none.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`.
    // NOTE(review): `input` is presumably the request text the key was
    // derived from, kept for inspection; confirm with the implementations.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2312
/// The user's answer to whether edit prediction data may be collected.
///
/// Implements `PartialEq`/`Eq` so choices can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No usable answer has been recorded.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2319
2320impl DataCollectionChoice {
2321 pub fn is_enabled(self, cx: &App) -> bool {
2322 if cx.is_staff() {
2323 return true;
2324 }
2325 match self {
2326 Self::Enabled => true,
2327 Self::NotAnswered | Self::Disabled => false,
2328 }
2329 }
2330
2331 #[must_use]
2332 pub fn toggle(&self) -> DataCollectionChoice {
2333 match self {
2334 Self::Enabled => Self::Disabled,
2335 Self::Disabled => Self::Enabled,
2336 Self::NotAnswered => Self::Enabled,
2337 }
2338 }
2339}
2340
2341impl From<bool> for DataCollectionChoice {
2342 fn from(value: bool) -> Self {
2343 match value {
2344 true => DataCollectionChoice::Enabled,
2345 false => DataCollectionChoice::Disabled,
2346 }
2347 }
2348}
2349
/// Marker type whose `Dismissable` implementation records whether the edit
/// prediction upsell has been dismissed.
struct ZedPredictUpsell;
2351
2352impl Dismissable for ZedPredictUpsell {
2353 const KEY: &'static str = "dismissed-edit-predict-upsell";
2354
2355 fn dismissed() -> bool {
2356 // To make this backwards compatible with older versions of Zed, we
2357 // check if the user has seen the previous Edit Prediction Onboarding
2358 // before, by checking the data collection choice which was written to
2359 // the database once the user clicked on "Accept and Enable"
2360 if KEY_VALUE_STORE
2361 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2362 .log_err()
2363 .is_some_and(|s| s.is_some())
2364 {
2365 return true;
2366 }
2367
2368 KEY_VALUE_STORE
2369 .read_kvp(Self::KEY)
2370 .log_err()
2371 .is_some_and(|s| s.is_some())
2372 }
2373}
2374
2375pub fn should_show_upsell_modal() -> bool {
2376 !ZedPredictUpsell::dismissed()
2377}
2378
2379pub fn init(cx: &mut App) {
2380 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2381 workspace.register_action(
2382 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2383 ZedPredictModal::toggle(
2384 workspace,
2385 workspace.user_store().clone(),
2386 workspace.client().clone(),
2387 window,
2388 cx,
2389 )
2390 },
2391 );
2392
2393 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2394 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2395 settings
2396 .project
2397 .all_languages
2398 .features
2399 .get_or_insert_default()
2400 .edit_prediction_provider = Some(EditPredictionProvider::None)
2401 });
2402 });
2403 })
2404 .detach();
2405}