1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 self, PromptFormat, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::EditPredictionExcerptOptions;
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, RefreshLlmTokenListener};
34use project::{Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{EditPredictionProvider, update_settings_file};
39use std::collections::{VecDeque, hash_map};
40use text::Edit;
41use workspace::Workspace;
42use zeta_prompt::ZetaVersion;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
50use std::{env, mem};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59mod onboarding_modal;
60pub mod open_ai_response;
61mod prediction;
62pub mod sweep_ai;
63
64pub mod udiff;
65
66mod capture_example;
67mod zed_edit_prediction_delegate;
68pub mod zeta1;
69pub mod zeta2;
70
71#[cfg(test)]
72mod edit_prediction_tests;
73
74use crate::capture_example::{
75 should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
76};
77use crate::license_detection::LicenseDetectionWatcher;
78use crate::mercury::Mercury;
79use crate::onboarding_modal::ZedPredictModal;
80pub use crate::prediction::EditPrediction;
81pub use crate::prediction::EditPredictionId;
82use crate::prediction::EditPredictionResult;
83pub use crate::sweep_ai::SweepAi;
84pub use capture_example::capture_example;
85pub use language_model::ApiKeyState;
86pub use telemetry_events::EditPredictionRating;
87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
88
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
// Two edits in the same buffer may be coalesced into one event when the row gap between
// their ranges is at most this many rows (checked in `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): not referenced in the visible part of this file; presumably the maximum
// time between edits for them to be grouped together — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// Key under which the data-collection choice is persisted (presumably in the key-value
// store, given the `KEY_VALUE_STORE` import) — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the batching/debounce interval for rejection reports sent by
// `handle_rejected_predictions` — confirm at the use site.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
105
106pub struct SweepFeatureFlag;
107
108impl FeatureFlag for SweepFeatureFlag {
109 const NAME: &str = "sweep-ai";
110}
111
112pub struct MercuryFeatureFlag;
113
114impl FeatureFlag for MercuryFeatureFlag {
115 const NAME: &str = "mercury";
116}
117
/// Default [`ZetaOptions`]; `EditPredictionStore::new` starts from these.
// The excerpt limits are interpreted by `edit_prediction_context`; they are presumably
// byte counts, as the field names suggest.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        // Share of the excerpt budget to spend before the cursor.
        target_before_cursor_over_total_bytes: 0.5,
    },
    prompt_format: PromptFormat::DEFAULT,
};
126
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value. In that case
/// `EditPredictionStore::new` falls back to a localhost Ollama endpoint.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| matches!(env::var("ZED_ZETA2_OLLAMA"), Ok(value) if !value.is_empty()));

/// Model identifier for edit predictions, resolved from the environment:
/// - `ZED_ZETA2_MODEL=zeta2-exp` maps to the fine-tuned model @ Baseten;
/// - any other `ZED_ZETA2_MODEL` value is used verbatim;
/// - otherwise `qwen3-coder:30b` when [`USE_OLLAMA`] is set;
/// - otherwise vanilla qwen3-coder @ Baseten.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL") {
        Ok(name) if name == "zeta2-exp" => String::from("4w5n28vw"),
        Ok(name) => name,
        Err(_) if *USE_OLLAMA => String::from("qwen3-coder:30b"),
        Err(_) => String::from("yqvev8r3"),
    }
});
139
140pub struct Zeta2FeatureFlag;
141
142impl FeatureFlag for Zeta2FeatureFlag {
143 const NAME: &'static str = "zeta2";
144
145 fn enabled_for_staff() -> bool {
146 true
147 }
148}
149
150pub struct EditPredictionExampleCaptureFeatureFlag;
151
152impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
153 const NAME: &'static str = "edit-prediction-example-capture";
154
155 fn enabled_for_staff() -> bool {
156 true
157 }
158}
159
/// App-global handle to the single [`EditPredictionStore`] entity. It is installed by
/// `EditPredictionStore::global` and read back by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
164
/// Holds edit-prediction state for every registered project, plus the backend configuration
/// and the channels and tokens used when making prediction requests.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token. It is refreshed whenever the global `RefreshLlmTokenListener` emits,
    /// and a clone is shared with the rejected-predictions background task.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Prompt/context options; initialized to `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    // NOTE(review): initialized to `false`; the code that sets it is not visible here.
    update_required: bool,
    #[cfg(feature = "cli-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Backend used for predictions; `new` selects `Zeta2` with the default version.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    /// Loaded once at construction via `load_data_collection_choice`.
    data_collection_choice: DataCollectionChoice,
    /// Sends rejection reports to the background task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): presumably predictions that were displayed to the user — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions that have already been rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
    /// Endpoint override. Taken from `ZED_PREDICT_EDITS_URL`, or a localhost Ollama
    /// endpoint when `ZED_ZETA2_OLLAMA` is set; `None` otherwise.
    custom_predict_edits_url: Option<Arc<Url>>,
}
184
/// Backend used to produce edit predictions.
// NOTE: `Default` yields `Zeta1`, but `EditPredictionStore::new` explicitly starts with
// `Zeta2 { version: Default::default() }`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
}
195
/// Inputs passed to an edit-prediction model backend.
// NOTE(review): the field meanings below are inferred from names and types; confirm them
// at the construction site, which is not visible in this excerpt.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    /// Buffer the prediction is for, and a snapshot of it.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Anchor the request is made at (presumably the cursor).
    position: Anchor,
    /// Recent edit events for the project.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    /// Recently active project paths (presumably most recent first, as in `ProjectState`).
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional channel for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
210
/// Options that control context-excerpt extraction and the prompt format.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Excerpt size limits passed to `edit_prediction_context`.
    pub context: EditPredictionExcerptOptions,
    pub prompt_format: predict_edits_v3::PromptFormat,
}
216
/// Diagnostic events delivered on the channel returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when a project's related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Empty when sent from `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Sent when a project's related-excerpt store finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Human-readable label/value pairs, e.g. cache hits and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}

// NOTE(review): the producers of the two prediction events below are not visible in this
// excerpt; presumably emitted around each model request.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
252
/// Alias for the debug info type of the v3 predict-edits API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// Maximum number of `UserActionRecord`s kept per project. When full, the oldest record is
/// evicted first (see `ProjectState::record_user_action`).
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user edit action.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the edited buffer.
    pub buffer_id: EntityId,
    /// Row of `offset` in the post-edit snapshot.
    pub line_number: u32,
    /// End offset, in the post-edit snapshot, of the last edit in the batch.
    pub offset: usize,
    /// Wall-clock milliseconds since the Unix epoch; 0 if the clock is before the epoch.
    pub timestamp_epoch_ms: u64,
}

/// Coarse classification of a recorded user action.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    /// Total inserted length equals the number of edits (with or without deletions).
    InsertChar,
    /// Any other edit that is not a pure deletion.
    InsertSelection,
    /// Pure deletion whose total deleted length equals the number of edits.
    DeleteChar,
    /// Any other pure deletion.
    DeleteSelection,
    // NOTE(review): never produced by the edit classification in
    // `report_changes_for_buffer`; presumably recorded elsewhere — confirm.
    CursorMovement,
}
274
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The prompt-facing event (e.g. a `BufferChange` with a unified diff).
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change described by `event`.
    pub old_snapshot: TextBufferSnapshot,
}
281
/// Per-project edit-prediction state owned by `EditPredictionStore`.
struct ProjectState {
    /// Completed edit events; the in-progress one lives in `last_event`.
    events: VecDeque<StoredEvent>,
    /// Edit event still being extended by further edits.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent at the front.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id; removed when the buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight requests (at most two, by `ArrayVec` capacity).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug channel installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably (buffer id, time) of the last prediction refresh — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions cancelled by `cancel_pending_prediction`.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context for this project.
    context: Entity<RelatedExcerptStore>,
    /// License watchers per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history (`USER_ACTION_HISTORY_SIZE`) of recent user actions.
    user_actions: VecDeque<UserActionRecord>,
    _subscription: gpui::Subscription,
    /// Copilot instance started lazily by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
299
300impl ProjectState {
301 fn record_user_action(&mut self, action: UserActionRecord) {
302 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
303 self.user_actions.pop_front();
304 }
305 self.user_actions.push_back(action);
306 }
307
308 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
309 self.events
310 .iter()
311 .cloned()
312 .chain(
313 self.last_event
314 .as_ref()
315 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
316 )
317 .collect()
318 }
319
320 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
321 self.events
322 .iter()
323 .cloned()
324 .chain(self.last_event.as_ref().iter().flat_map(|event| {
325 let (one, two) = event.split_by_pause();
326 let one = one.finalize(&self.license_detection_watchers, cx);
327 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
328 one.into_iter().chain(two)
329 }))
330 .collect()
331 }
332
333 fn cancel_pending_prediction(
334 &mut self,
335 pending_prediction: PendingPrediction,
336 cx: &mut Context<EditPredictionStore>,
337 ) {
338 self.cancelled_predictions.insert(pending_prediction.id);
339
340 cx.spawn(async move |this, cx| {
341 let Some(prediction_id) = pending_prediction.task.await else {
342 return;
343 };
344
345 this.update(cx, |this, _cx| {
346 this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
347 })
348 .ok();
349 })
350 .detach()
351 }
352
353 fn active_buffer(
354 &self,
355 project: &Entity<Project>,
356 cx: &App,
357 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
358 let project = project.read(cx);
359 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
360 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
361 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
362 Some((active_buffer, registered_buffer.last_position))
363 }
364}
365
/// The prediction currently held for a project, with information about how it was requested
/// and shown.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed — confirm.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
373
374impl CurrentEditPrediction {
375 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
376 let Some(new_edits) = self
377 .prediction
378 .interpolate(&self.prediction.buffer.read(cx))
379 else {
380 return false;
381 };
382
383 if self.prediction.buffer != old_prediction.prediction.buffer {
384 return true;
385 }
386
387 let Some(old_edits) = old_prediction
388 .prediction
389 .interpolate(&old_prediction.prediction.buffer.read(cx))
390 else {
391 return true;
392 };
393
394 let requested_by_buffer_id = self.requested_by.buffer_id();
395
396 // This reduces the occurrence of UI thrash from replacing edits
397 //
398 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
399 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
400 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
401 && old_edits.len() == 1
402 && new_edits.len() == 1
403 {
404 let (old_range, old_text) = &old_edits[0];
405 let (new_range, new_text) = &new_edits[0];
406 new_range == old_range && new_text.starts_with(old_text.as_ref())
407 } else {
408 true
409 }
410 }
411}
412
413#[derive(Debug, Clone)]
414enum PredictionRequestedBy {
415 DiagnosticsUpdate,
416 Buffer(EntityId),
417}
418
419impl PredictionRequestedBy {
420 pub fn buffer_id(&self) -> Option<EntityId> {
421 match self {
422 PredictionRequestedBy::DiagnosticsUpdate => None,
423 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
424 }
425 }
426}
427
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Added to `ProjectState::cancelled_predictions` when the request is cancelled.
    id: usize,
    /// Resolves to the id of the produced prediction, or `None`.
    task: Task<Option<EditPredictionId>>,
}
433
434/// A prediction from the perspective of a buffer.
435#[derive(Debug)]
436enum BufferEditPrediction<'a> {
437 Local { prediction: &'a EditPrediction },
438 Jump { prediction: &'a EditPrediction },
439}
440
441#[cfg(test)]
442impl std::ops::Deref for BufferEditPrediction<'_> {
443 type Target = EditPrediction;
444
445 fn deref(&self) -> &Self::Target {
446 match self {
447 BufferEditPrediction::Local { prediction } => prediction,
448 BufferEditPrediction::Jump { prediction } => prediction,
449 }
450 }
451}
452
/// Bookkeeping for a buffer registered with a project's edit-prediction state.
struct RegisteredBuffer {
    /// The buffer's file as of the last reported change.
    file: Option<Arc<dyn File>>,
    /// Text snapshot as of the last reported change; the next edit is diffed against it.
    snapshot: TextBufferSnapshot,
    /// Last recorded position in this buffer (presumably the cursor); returned by
    /// `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    /// Buffer `Edited` subscription and release observer.
    _subscriptions: [gpui::Subscription; 2],
}
459
/// A project's most recent, still-open edit event, spanning `old_snapshot` to `new_snapshot`
/// of a single buffer.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Anchor range covering the edits so far. Used to decide whether the next edit is close
    /// enough (`CHANGE_GROUPING_LINE_SPAN`) to be coalesced into this event.
    edit_range: Option<Range<Anchor>>,
    /// Snapshot taken at the last editing pause, if any; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    // NOTE(review): presumably when the most recent edit was folded in — confirm.
    last_edit_time: Option<Instant>,
}
470
471impl LastEvent {
472 pub fn finalize(
473 &self,
474 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
475 cx: &App,
476 ) -> Option<StoredEvent> {
477 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
478 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
479
480 let in_open_source_repo =
481 [self.new_file.as_ref(), self.old_file.as_ref()]
482 .iter()
483 .all(|file| {
484 file.is_some_and(|file| {
485 license_detection_watchers
486 .get(&file.worktree_id(cx))
487 .is_some_and(|watcher| watcher.is_project_open_source())
488 })
489 });
490
491 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
492
493 if path == old_path && diff.is_empty() {
494 None
495 } else {
496 Some(StoredEvent {
497 event: Arc::new(zeta_prompt::Event::BufferChange {
498 old_path,
499 path,
500 diff,
501 in_open_source_repo,
502 // TODO: Actually detect if this edit was predicted or not
503 predicted: false,
504 }),
505 old_snapshot: self.old_snapshot.clone(),
506 })
507 }
508 }
509
510 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
511 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
512 return (self.clone(), None);
513 };
514
515 let before = LastEvent {
516 old_snapshot: self.old_snapshot.clone(),
517 new_snapshot: boundary_snapshot.clone(),
518 old_file: self.old_file.clone(),
519 new_file: self.new_file.clone(),
520 edit_range: None,
521 snapshot_after_last_editing_pause: None,
522 last_edit_time: self.last_edit_time,
523 };
524
525 let after = LastEvent {
526 old_snapshot: boundary_snapshot.clone(),
527 new_snapshot: self.new_snapshot.clone(),
528 old_file: self.old_file.clone(),
529 new_file: self.new_file.clone(),
530 edit_range: None,
531 snapshot_after_last_editing_pause: None,
532 last_edit_time: self.last_edit_time,
533 };
534
535 (before, Some(after))
536 }
537}
538
539pub(crate) fn compute_diff_between_snapshots(
540 old_snapshot: &TextBufferSnapshot,
541 new_snapshot: &TextBufferSnapshot,
542) -> Option<String> {
543 let edits: Vec<Edit<usize>> = new_snapshot
544 .edits_since::<usize>(&old_snapshot.version)
545 .collect();
546
547 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
548
549 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
550 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
551 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
552 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
553
554 const CONTEXT_LINES: u32 = 3;
555
556 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
557 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
558 let old_context_end_row =
559 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
560 let new_context_end_row =
561 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
562
563 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
564 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
565 let old_end_line_offset = old_snapshot
566 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
567 let new_end_line_offset = new_snapshot
568 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
569 let old_edit_range = old_start_line_offset..old_end_line_offset;
570 let new_edit_range = new_start_line_offset..new_end_line_offset;
571
572 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
573 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
574
575 let diff = language::unified_diff_with_offsets(
576 &old_region_text,
577 &new_region_text,
578 old_context_start_row,
579 new_context_start_row,
580 );
581
582 Some(diff)
583}
584
585fn buffer_path_with_id_fallback(
586 file: Option<&Arc<dyn File>>,
587 snapshot: &TextBufferSnapshot,
588 cx: &App,
589) -> Arc<Path> {
590 if let Some(file) = file {
591 file.full_path(cx).into()
592 } else {
593 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
594 }
595}
596
597impl EditPredictionStore {
598 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
599 cx.try_global::<EditPredictionStoreGlobal>()
600 .map(|global| global.0.clone())
601 }
602
603 pub fn global(
604 client: &Arc<Client>,
605 user_store: &Entity<UserStore>,
606 cx: &mut App,
607 ) -> Entity<Self> {
608 cx.try_global::<EditPredictionStoreGlobal>()
609 .map(|global| global.0.clone())
610 .unwrap_or_else(|| {
611 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
612 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
613 ep_store
614 })
615 }
616
617 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
618 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
619 let data_collection_choice = Self::load_data_collection_choice();
620
621 let llm_token = LlmApiToken::default();
622
623 let (reject_tx, reject_rx) = mpsc::unbounded();
624 cx.background_spawn({
625 let client = client.clone();
626 let llm_token = llm_token.clone();
627 let app_version = AppVersion::global(cx);
628 let background_executor = cx.background_executor().clone();
629 async move {
630 Self::handle_rejected_predictions(
631 reject_rx,
632 client,
633 llm_token,
634 app_version,
635 background_executor,
636 )
637 .await
638 }
639 })
640 .detach();
641
642 let this = Self {
643 projects: HashMap::default(),
644 client,
645 user_store,
646 options: DEFAULT_OPTIONS,
647 llm_token,
648 _llm_token_subscription: cx.subscribe(
649 &refresh_llm_token_listener,
650 |this, _listener, _event, cx| {
651 let client = this.client.clone();
652 let llm_token = this.llm_token.clone();
653 cx.spawn(async move |_this, _cx| {
654 llm_token.refresh(&client).await?;
655 anyhow::Ok(())
656 })
657 .detach_and_log_err(cx);
658 },
659 ),
660 update_required: false,
661 #[cfg(feature = "cli-support")]
662 eval_cache: None,
663 edit_prediction_model: EditPredictionModel::Zeta2 {
664 version: Default::default(),
665 },
666 sweep_ai: SweepAi::new(cx),
667 mercury: Mercury::new(cx),
668
669 data_collection_choice,
670 reject_predictions_tx: reject_tx,
671 rated_predictions: Default::default(),
672 shown_predictions: Default::default(),
673 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
674 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
675 Err(_) => {
676 if *USE_OLLAMA {
677 Some(
678 Url::parse("http://localhost:11434/v1/chat/completions")
679 .unwrap()
680 .into(),
681 )
682 } else {
683 None
684 }
685 }
686 },
687 };
688
689 this
690 }
691
/// Test-only: overrides the predict-edits endpoint.
#[cfg(test)]
pub fn set_custom_predict_edits_url(&mut self, url: Url) {
    self.custom_predict_edits_url = Some(Arc::new(url));
}
696
697 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
698 self.edit_prediction_model = model;
699 }
700
701 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
702 self.sweep_ai.api_token.read(cx).has_key()
703 }
704
705 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
706 self.mercury.api_token.read(cx).has_key()
707 }
708
/// Installs an evaluation cache (CLI support builds only), replacing any previous one.
#[cfg(feature = "cli-support")]
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
    let _previous = mem::replace(&mut self.eval_cache, Some(cache));
}
713
714 pub fn options(&self) -> &ZetaOptions {
715 &self.options
716 }
717
718 pub fn set_options(&mut self, options: ZetaOptions) {
719 self.options = options;
720 }
721
722 pub fn clear_history(&mut self) {
723 for project_state in self.projects.values_mut() {
724 project_state.events.clear();
725 project_state.last_event.take();
726 }
727 }
728
729 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
730 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
731 project_state.events.clear();
732 project_state.last_event.take();
733 }
734 }
735
736 pub fn edit_history_for_project(
737 &self,
738 project: &Entity<Project>,
739 cx: &App,
740 ) -> Vec<StoredEvent> {
741 self.projects
742 .get(&project.entity_id())
743 .map(|project_state| project_state.events(cx))
744 .unwrap_or_default()
745 }
746
747 pub fn edit_history_for_project_with_pause_split_last_event(
748 &self,
749 project: &Entity<Project>,
750 cx: &App,
751 ) -> Vec<StoredEvent> {
752 self.projects
753 .get(&project.entity_id())
754 .map(|project_state| project_state.events_split_by_pause(cx))
755 .unwrap_or_default()
756 }
757
758 pub fn context_for_project<'a>(
759 &'a self,
760 project: &Entity<Project>,
761 cx: &'a mut App,
762 ) -> Vec<RelatedFile> {
763 self.projects
764 .get(&project.entity_id())
765 .map(|project| {
766 project
767 .context
768 .update(cx, |context, cx| context.related_files(cx))
769 })
770 .unwrap_or_default()
771 }
772
773 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
774 self.projects
775 .get(&project.entity_id())
776 .and_then(|project| project.copilot.clone())
777 }
778
779 pub fn start_copilot_for_project(
780 &mut self,
781 project: &Entity<Project>,
782 cx: &mut Context<Self>,
783 ) -> Option<Entity<Copilot>> {
784 let state = self.get_or_init_project(project, cx);
785
786 if state.copilot.is_some() {
787 return state.copilot.clone();
788 }
789 let _project = project.clone();
790 let project = project.read(cx);
791
792 let node = project.node_runtime().cloned();
793 if let Some(node) = node {
794 let next_id = project.languages().next_language_server_id();
795 let fs = project.fs().clone();
796
797 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
798 state.copilot = Some(copilot.clone());
799 Some(copilot)
800 } else {
801 None
802 }
803 }
804
805 pub fn context_for_project_with_buffers<'a>(
806 &'a self,
807 project: &Entity<Project>,
808 cx: &'a mut App,
809 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
810 self.projects
811 .get(&project.entity_id())
812 .map(|project| {
813 project
814 .context
815 .update(cx, |context, cx| context.related_files_with_buffers(cx))
816 })
817 .unwrap_or_default()
818 }
819
820 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
821 if matches!(
822 self.edit_prediction_model,
823 EditPredictionModel::Zeta2 { .. }
824 ) {
825 self.user_store.read(cx).edit_prediction_usage()
826 } else {
827 None
828 }
829 }
830
831 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
832 self.get_or_init_project(project, cx);
833 }
834
835 pub fn register_buffer(
836 &mut self,
837 buffer: &Entity<Buffer>,
838 project: &Entity<Project>,
839 cx: &mut Context<Self>,
840 ) {
841 let project_state = self.get_or_init_project(project, cx);
842 Self::register_buffer_impl(project_state, buffer, project, cx);
843 }
844
845 fn get_or_init_project(
846 &mut self,
847 project: &Entity<Project>,
848 cx: &mut Context<Self>,
849 ) -> &mut ProjectState {
850 let entity_id = project.entity_id();
851 self.projects
852 .entry(entity_id)
853 .or_insert_with(|| ProjectState {
854 context: {
855 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
856 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
857 this.handle_excerpt_store_event(entity_id, event);
858 })
859 .detach();
860 related_excerpt_store
861 },
862 events: VecDeque::new(),
863 last_event: None,
864 recent_paths: VecDeque::new(),
865 debug_tx: None,
866 registered_buffers: HashMap::default(),
867 current_prediction: None,
868 cancelled_predictions: HashSet::default(),
869 pending_predictions: ArrayVec::new(),
870 next_pending_prediction_id: 0,
871 last_prediction_refresh: None,
872 license_detection_watchers: HashMap::default(),
873 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
874 _subscription: cx.subscribe(&project, Self::handle_project_event),
875 copilot: None,
876 })
877 }
878
879 pub fn remove_project(&mut self, project: &Entity<Project>) {
880 self.projects.remove(&project.entity_id());
881 }
882
883 fn handle_excerpt_store_event(
884 &mut self,
885 project_entity_id: EntityId,
886 event: &RelatedExcerptStoreEvent,
887 ) {
888 if let Some(project_state) = self.projects.get(&project_entity_id) {
889 if let Some(debug_tx) = project_state.debug_tx.clone() {
890 match event {
891 RelatedExcerptStoreEvent::StartedRefresh => {
892 debug_tx
893 .unbounded_send(DebugEvent::ContextRetrievalStarted(
894 ContextRetrievalStartedDebugEvent {
895 project_entity_id: project_entity_id,
896 timestamp: Instant::now(),
897 search_prompt: String::new(),
898 },
899 ))
900 .ok();
901 }
902 RelatedExcerptStoreEvent::FinishedRefresh {
903 cache_hit_count,
904 cache_miss_count,
905 mean_definition_latency,
906 max_definition_latency,
907 } => {
908 debug_tx
909 .unbounded_send(DebugEvent::ContextRetrievalFinished(
910 ContextRetrievalFinishedDebugEvent {
911 project_entity_id: project_entity_id,
912 timestamp: Instant::now(),
913 metadata: vec![
914 (
915 "Cache Hits",
916 format!(
917 "{}/{}",
918 cache_hit_count,
919 cache_hit_count + cache_miss_count
920 )
921 .into(),
922 ),
923 (
924 "Max LSP Time",
925 format!("{} ms", max_definition_latency.as_millis())
926 .into(),
927 ),
928 (
929 "Mean LSP Time",
930 format!("{} ms", mean_definition_latency.as_millis())
931 .into(),
932 ),
933 ],
934 },
935 ))
936 .ok();
937 }
938 }
939 }
940 }
941 }
942
943 pub fn debug_info(
944 &mut self,
945 project: &Entity<Project>,
946 cx: &mut Context<Self>,
947 ) -> mpsc::UnboundedReceiver<DebugEvent> {
948 let project_state = self.get_or_init_project(project, cx);
949 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
950 project_state.debug_tx = Some(debug_watch_tx);
951 debug_watch_rx
952 }
953
954 fn handle_project_event(
955 &mut self,
956 project: Entity<Project>,
957 event: &project::Event,
958 cx: &mut Context<Self>,
959 ) {
960 // TODO [zeta2] init with recent paths
961 match event {
962 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
963 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
964 return;
965 };
966 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
967 if let Some(path) = path {
968 if let Some(ix) = project_state
969 .recent_paths
970 .iter()
971 .position(|probe| probe == &path)
972 {
973 project_state.recent_paths.remove(ix);
974 }
975 project_state.recent_paths.push_front(path);
976 }
977 }
978 project::Event::DiagnosticsUpdated { .. } => {
979 if cx.has_flag::<Zeta2FeatureFlag>() {
980 self.refresh_prediction_from_diagnostics(project, cx);
981 }
982 }
983 _ => (),
984 }
985 }
986
/// Registers `buffer` with `project_state` if it is not already registered, and returns its
/// [`RegisteredBuffer`] entry.
///
/// If the buffer has a file in one of `project`'s worktrees, this also makes sure that
/// worktree has a license-detection watcher. The watcher is removed when the worktree is
/// released. New registrations capture the current snapshot and file, subscribe to `Edited`
/// events (forwarded to `report_changes_for_buffer`), and unregister the buffer when it is
/// released.
fn register_buffer_impl<'a>(
    project_state: &'a mut ProjectState,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &'a mut RegisteredBuffer {
    let buffer_id = buffer.entity_id();

    // Make sure the buffer's worktree has a license watcher (created at most once).
    if let Some(file) = buffer.read(cx).file() {
        let worktree_id = file.worktree_id(cx);
        if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
            project_state
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| {
                    let project_entity_id = project.entity_id();
                    // Drop the watcher when its worktree goes away.
                    cx.observe_release(&worktree, move |this, _worktree, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state
                            .license_detection_watchers
                            .remove(&worktree_id);
                    })
                    .detach();
                    Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                });
        }
    }

    match project_state.registered_buffers.entry(buffer_id) {
        hash_map::Entry::Occupied(entry) => entry.into_mut(),
        hash_map::Entry::Vacant(entry) => {
            let buf = buffer.read(cx);
            let snapshot = buf.text_snapshot();
            let file = buf.file().cloned();
            let project_entity_id = project.entity_id();
            entry.insert(RegisteredBuffer {
                snapshot,
                file,
                last_position: None,
                _subscriptions: [
                    // Report edits. The project is held weakly so this subscription does not
                    // keep it alive.
                    cx.subscribe(buffer, {
                        let project = project.downgrade();
                        move |this, buffer, event, cx| {
                            if let language::BufferEvent::Edited = event
                                && let Some(project) = project.upgrade()
                            {
                                this.report_changes_for_buffer(&buffer, &project, cx);
                            }
                        }
                    }),
                    // Unregister the buffer when it is released.
                    cx.observe_release(buffer, move |this, _buffer, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state.registered_buffers.remove(&buffer_id);
                    }),
                ],
            })
        }
    }
}
1052
    /// Folds the latest edits of `buffer` into `project`'s edit history.
    ///
    /// The buffer is registered with the project state if necessary. When its
    /// text version is unchanged since the last recorded snapshot, this returns
    /// without doing anything. Otherwise it:
    /// - swaps the recorded file/snapshot for the current ones,
    /// - records a [`UserActionRecord`] summarizing the edits, and
    /// - either coalesces the change into the pending `last_event`, or finalizes
    ///   that event into `events` and starts a new pending event.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the snapshot recorded for this buffer.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Store the current file/snapshot and keep the previous ones to diff against.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Anchor range from the first edit's start to the last edit's end.
        let mut edit_range: Option<Range<Anchor>> = None;
        // End offset (in the new snapshot) of the last edit seen.
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // Classify by total deleted/inserted length versus the number of edits.
            // Exactly one unit per edit counts as a single-character action (e.g.
            // typing or deleting with several cursors). Anything else counts as a
            // selection-sized action. Edits that both delete and insert are
            // classified by their insertions.
            // NOTE(review): lengths come from `usize` offset ranges, presumably
            // bytes, so a multi-byte character may count as a selection — confirm.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Falls back to 0 if the system clock is before the Unix epoch.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The change directly continues the pending event: same buffer, and it
            // starts from exactly the snapshot version that event ended with.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Additionally, its edits must overlap the pending event's edit range
            // or lie within `CHANGE_GROUPING_LINE_SPAN` rows of it.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least `LAST_CHANGE_GROUPING_TIME` passed since the previous
                // edit, remember the snapshot from just before this edit as the
                // point where the user last paused.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                // Only the latest edit range is kept, so the next proximity check
                // is measured against the most recent edit.
                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: finalize the pending event into history.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest event so at most `EVENT_COUNT_MAX - 1` finalized
                // events are stored.
                // NOTE(review): presumably this leaves room for the pending
                // `last_event` — confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1176
1177 fn prediction_at(
1178 &mut self,
1179 buffer: &Entity<Buffer>,
1180 position: Option<language::Anchor>,
1181 project: &Entity<Project>,
1182 cx: &App,
1183 ) -> Option<BufferEditPrediction<'_>> {
1184 let project_state = self.projects.get_mut(&project.entity_id())?;
1185 if let Some(position) = position
1186 && let Some(buffer) = project_state
1187 .registered_buffers
1188 .get_mut(&buffer.entity_id())
1189 {
1190 buffer.last_position = Some(position);
1191 }
1192
1193 let CurrentEditPrediction {
1194 requested_by,
1195 prediction,
1196 ..
1197 } = project_state.current_prediction.as_ref()?;
1198
1199 if prediction.targets_buffer(buffer.read(cx)) {
1200 Some(BufferEditPrediction::Local { prediction })
1201 } else {
1202 let show_jump = match requested_by {
1203 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1204 requested_by_buffer_id == &buffer.entity_id()
1205 }
1206 PredictionRequestedBy::DiagnosticsUpdate => true,
1207 };
1208
1209 if show_jump {
1210 Some(BufferEditPrediction::Jump { prediction })
1211 } else {
1212 None
1213 }
1214 }
1215 }
1216
1217 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1218 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1219 return;
1220 };
1221
1222 let Some(current_prediction) = project_state.current_prediction.take() else {
1223 return;
1224 };
1225
1226 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1227 project_state.cancel_pending_prediction(pending_prediction, cx);
1228 }
1229
1230 match self.edit_prediction_model {
1231 EditPredictionModel::Sweep => {
1232 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1233 }
1234 EditPredictionModel::Mercury => {}
1235 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1236 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1237 }
1238 }
1239 }
1240
1241 async fn handle_rejected_predictions(
1242 rx: UnboundedReceiver<EditPredictionRejection>,
1243 client: Arc<Client>,
1244 llm_token: LlmApiToken,
1245 app_version: Version,
1246 background_executor: BackgroundExecutor,
1247 ) {
1248 let mut rx = std::pin::pin!(rx.peekable());
1249 let mut batched = Vec::new();
1250
1251 while let Some(rejection) = rx.next().await {
1252 batched.push(rejection);
1253
1254 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1255 select_biased! {
1256 next = rx.as_mut().peek().fuse() => {
1257 if next.is_some() {
1258 continue;
1259 }
1260 }
1261 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1262 }
1263 }
1264
1265 let url = client
1266 .http_client()
1267 .build_zed_llm_url("/predict_edits/reject", &[])
1268 .unwrap();
1269
1270 let flush_count = batched
1271 .len()
1272 // in case items have accumulated after failure
1273 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1274 let start = batched.len() - flush_count;
1275
1276 let body = RejectEditPredictionsBodyRef {
1277 rejections: &batched[start..],
1278 };
1279
1280 let result = Self::send_api_request::<()>(
1281 |builder| {
1282 let req = builder
1283 .uri(url.as_ref())
1284 .body(serde_json::to_string(&body)?.into());
1285 anyhow::Ok(req?)
1286 },
1287 client.clone(),
1288 llm_token.clone(),
1289 app_version.clone(),
1290 true,
1291 )
1292 .await;
1293
1294 if result.log_err().is_some() {
1295 batched.drain(start..);
1296 }
1297 }
1298 }
1299
1300 fn reject_current_prediction(
1301 &mut self,
1302 reason: EditPredictionRejectReason,
1303 project: &Entity<Project>,
1304 ) {
1305 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1306 project_state.pending_predictions.clear();
1307 if let Some(prediction) = project_state.current_prediction.take() {
1308 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1309 }
1310 };
1311 }
1312
1313 fn did_show_current_prediction(
1314 &mut self,
1315 project: &Entity<Project>,
1316 display_type: edit_prediction_types::SuggestionDisplayType,
1317 cx: &mut Context<Self>,
1318 ) {
1319 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1320 return;
1321 };
1322
1323 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1324 return;
1325 };
1326
1327 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1328 let previous_shown_with = current_prediction.shown_with;
1329
1330 if previous_shown_with.is_none() || !is_jump {
1331 current_prediction.shown_with = Some(display_type);
1332 }
1333
1334 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1335
1336 if is_first_non_jump_show {
1337 current_prediction.was_shown = true;
1338 }
1339
1340 let display_type_changed = previous_shown_with != Some(display_type);
1341
1342 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1343 sweep_ai::edit_prediction_shown(
1344 &self.sweep_ai,
1345 self.client.clone(),
1346 ¤t_prediction.prediction,
1347 display_type,
1348 cx,
1349 );
1350 }
1351
1352 if is_first_non_jump_show {
1353 self.shown_predictions
1354 .push_front(current_prediction.prediction.clone());
1355 if self.shown_predictions.len() > 50 {
1356 let completion = self.shown_predictions.pop_back().unwrap();
1357 self.rated_predictions.remove(&completion.id);
1358 }
1359 }
1360 }
1361
1362 fn reject_prediction(
1363 &mut self,
1364 prediction_id: EditPredictionId,
1365 reason: EditPredictionRejectReason,
1366 was_shown: bool,
1367 ) {
1368 match self.edit_prediction_model {
1369 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1370 if self.custom_predict_edits_url.is_some() {
1371 return;
1372 }
1373 }
1374 EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
1375 }
1376
1377 self.reject_predictions_tx
1378 .unbounded_send(EditPredictionRejection {
1379 request_id: prediction_id.to_string(),
1380 reason,
1381 was_shown,
1382 })
1383 .log_err();
1384 }
1385
1386 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1387 self.projects
1388 .get(&project.entity_id())
1389 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1390 }
1391
1392 pub fn refresh_prediction_from_buffer(
1393 &mut self,
1394 project: Entity<Project>,
1395 buffer: Entity<Buffer>,
1396 position: language::Anchor,
1397 cx: &mut Context<Self>,
1398 ) {
1399 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1400 let Some(request_task) = this
1401 .update(cx, |this, cx| {
1402 this.request_prediction(
1403 &project,
1404 &buffer,
1405 position,
1406 PredictEditsRequestTrigger::Other,
1407 cx,
1408 )
1409 })
1410 .log_err()
1411 else {
1412 return Task::ready(anyhow::Ok(None));
1413 };
1414
1415 cx.spawn(async move |_cx| {
1416 request_task.await.map(|prediction_result| {
1417 prediction_result.map(|prediction_result| {
1418 (
1419 prediction_result,
1420 PredictionRequestedBy::Buffer(buffer.entity_id()),
1421 )
1422 })
1423 })
1424 })
1425 })
1426 }
1427
    /// Queues a prediction refresh triggered by a diagnostics update.
    ///
    /// Does nothing when the project is untracked or already has a current
    /// prediction. Otherwise, when the (throttled) refresh runs, it:
    /// - reads the active buffer and cursor,
    /// - checks that predictions are enabled there,
    /// - locates a diagnostic via `next_diagnostic_location`, and
    /// - requests a prediction at that location with the `Diagnostics` trigger.
    ///
    /// If a current prediction appeared while the request was in flight, the
    /// result is turned into a `CurrentPreferred` rejection instead.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer: a diagnostics-driven request never
        // competes with an existing current prediction.
        if project_state.current_prediction.is_some() {
            return;
        };

        // Diagnostics-driven refreshes use the project's id as their throttle key.
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, fall back to `Point::default()`.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // An empty (default) search range is passed here.
                // NOTE(review): presumably this excludes no diagnostic of the
                // active buffer — confirm `overlaps` semantics for empty ranges.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // Re-check: a buffer-driven prediction may have become current
                // while this request was running; if so, it wins.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1514
1515 fn predictions_enabled_at(
1516 snapshot: &BufferSnapshot,
1517 position: Option<language::Anchor>,
1518 cx: &App,
1519 ) -> bool {
1520 let file = snapshot.file();
1521 let all_settings = all_language_settings(file, cx);
1522 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1523 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1524 {
1525 return false;
1526 }
1527
1528 if let Some(last_position) = position {
1529 let settings = snapshot.settings_at(last_position, cx);
1530
1531 if !settings.edit_predictions_disabled_in.is_empty()
1532 && let Some(scope) = snapshot.language_scope_at(last_position)
1533 && let Some(scope_name) = scope.override_name()
1534 && settings
1535 .edit_predictions_disabled_in
1536 .iter()
1537 .any(|s| s == scope_name)
1538 {
1539 return false;
1540 }
1541 }
1542
1543 true
1544 }
1545
    /// Minimum delay between the start of two prediction refreshes that share
    /// the same throttle key (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Throttling is disabled in tests.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1550
1551 fn queue_prediction_refresh(
1552 &mut self,
1553 project: Entity<Project>,
1554 throttle_entity: EntityId,
1555 cx: &mut Context<Self>,
1556 do_refresh: impl FnOnce(
1557 WeakEntity<Self>,
1558 &mut AsyncApp,
1559 )
1560 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1561 + 'static,
1562 ) {
1563 let project_state = self.get_or_init_project(&project, cx);
1564 let pending_prediction_id = project_state.next_pending_prediction_id;
1565 project_state.next_pending_prediction_id += 1;
1566 let last_request = project_state.last_prediction_refresh;
1567
1568 let task = cx.spawn(async move |this, cx| {
1569 if let Some((last_entity, last_timestamp)) = last_request
1570 && throttle_entity == last_entity
1571 && let Some(timeout) =
1572 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1573 {
1574 cx.background_executor().timer(timeout).await;
1575 }
1576
1577 // If this task was cancelled before the throttle timeout expired,
1578 // do not perform a request.
1579 let mut is_cancelled = true;
1580 this.update(cx, |this, cx| {
1581 let project_state = this.get_or_init_project(&project, cx);
1582 if !project_state
1583 .cancelled_predictions
1584 .remove(&pending_prediction_id)
1585 {
1586 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1587 is_cancelled = false;
1588 }
1589 })
1590 .ok();
1591 if is_cancelled {
1592 return None;
1593 }
1594
1595 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1596 let new_prediction_id = new_prediction_result
1597 .as_ref()
1598 .map(|(prediction, _)| prediction.id.clone());
1599
1600 // When a prediction completes, remove it from the pending list, and cancel
1601 // any pending predictions that were enqueued before it.
1602 this.update(cx, |this, cx| {
1603 let project_state = this.get_or_init_project(&project, cx);
1604
1605 let is_cancelled = project_state
1606 .cancelled_predictions
1607 .remove(&pending_prediction_id);
1608
1609 let new_current_prediction = if !is_cancelled
1610 && let Some((prediction_result, requested_by)) = new_prediction_result
1611 {
1612 match prediction_result.prediction {
1613 Ok(prediction) => {
1614 let new_prediction = CurrentEditPrediction {
1615 requested_by,
1616 prediction,
1617 was_shown: false,
1618 shown_with: None,
1619 };
1620
1621 if let Some(current_prediction) =
1622 project_state.current_prediction.as_ref()
1623 {
1624 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1625 {
1626 this.reject_current_prediction(
1627 EditPredictionRejectReason::Replaced,
1628 &project,
1629 );
1630
1631 Some(new_prediction)
1632 } else {
1633 this.reject_prediction(
1634 new_prediction.prediction.id,
1635 EditPredictionRejectReason::CurrentPreferred,
1636 false,
1637 );
1638 None
1639 }
1640 } else {
1641 Some(new_prediction)
1642 }
1643 }
1644 Err(reject_reason) => {
1645 this.reject_prediction(prediction_result.id, reject_reason, false);
1646 None
1647 }
1648 }
1649 } else {
1650 None
1651 };
1652
1653 let project_state = this.get_or_init_project(&project, cx);
1654
1655 if let Some(new_prediction) = new_current_prediction {
1656 project_state.current_prediction = Some(new_prediction);
1657 }
1658
1659 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1660 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1661 if pending_prediction.id == pending_prediction_id {
1662 pending_predictions.remove(ix);
1663 for pending_prediction in pending_predictions.drain(0..ix) {
1664 project_state.cancel_pending_prediction(pending_prediction, cx)
1665 }
1666 break;
1667 }
1668 }
1669 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1670 cx.notify();
1671 })
1672 .ok();
1673
1674 new_prediction_id
1675 });
1676
1677 if project_state.pending_predictions.len() <= 1 {
1678 project_state.pending_predictions.push(PendingPrediction {
1679 id: pending_prediction_id,
1680 task,
1681 });
1682 } else if project_state.pending_predictions.len() == 2 {
1683 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1684 project_state.pending_predictions.push(PendingPrediction {
1685 id: pending_prediction_id,
1686 task,
1687 });
1688 project_state.cancel_pending_prediction(pending_prediction, cx);
1689 }
1690 }
1691
1692 pub fn request_prediction(
1693 &mut self,
1694 project: &Entity<Project>,
1695 active_buffer: &Entity<Buffer>,
1696 position: language::Anchor,
1697 trigger: PredictEditsRequestTrigger,
1698 cx: &mut Context<Self>,
1699 ) -> Task<Result<Option<EditPredictionResult>>> {
1700 self.request_prediction_internal(
1701 project.clone(),
1702 active_buffer.clone(),
1703 position,
1704 trigger,
1705 cx.has_flag::<Zeta2FeatureFlag>(),
1706 cx,
1707 )
1708 }
1709
    /// Builds the model inputs for a prediction at `position` in `active_buffer`
    /// and dispatches the request to the configured model.
    ///
    /// Before dispatching, this:
    /// - may append a synthetic cursor-movement action to the user actions;
    /// - may capture a telemetry example when data collection is allowed for the
    ///   file and the events;
    /// - with Zeta1, may also fire (and detach) a Zeta2 request with the
    ///   `Testing` trigger.
    ///
    /// When `allow_jump` is set, the model returns no prediction, and the project
    /// has recorded edit events, the request is retried once (with `allow_jump`
    /// false) at the location chosen by `next_diagnostic_location`.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above and below the cursor whose diagnostics belong to this
        // request; a jump only targets diagnostics outside this window.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // If the cursor moved away from where the last recorded action in this
        // buffer happened, append a synthetic cursor-movement action.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx);

        // NOTE(review): `should_sample_edit_prediction_example_capture` presumably
        // selects only a fraction of requests — confirm.
        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                false,
                cx,
            ) {
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => {
                // Optionally mirror the request to Zeta2 for testing; its result
                // is discarded.
                if should_send_testing_zeta2_request() {
                    let mut zeta2_inputs = inputs.clone();
                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
                    zeta2::request_prediction_with_zeta2(
                        self,
                        zeta2_inputs,
                        Default::default(),
                        cx,
                    )
                    .detach();
                }
                zeta1::request_prediction_with_zeta1(self, inputs, cx)
            }
            EditPredictionModel::Zeta2 { version } => {
                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
            }
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction at the cursor: retry once at a diagnostic elsewhere,
            // but only if there is recent edit history.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1856
    /// Picks a location to jump to based on diagnostics.
    ///
    /// It first searches the active buffer. Among diagnostic groups whose
    /// primary entry does not overlap `active_buffer_diagnostic_search_range`,
    /// it picks the one whose start row is closest to the cursor row.
    ///
    /// If there is none, it opens the file with diagnostic summaries (other than
    /// the active buffer's own path) that shares the most leading path
    /// components with the active buffer's path. It then returns the start of
    /// that buffer's diagnostic with the smallest `severity` value (presumably
    /// the most severe, as in LSP).
    ///
    /// Returns `Ok(None)` when no candidate exists. An error is returned only if
    /// opening the other buffer fails.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another file with diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1935
    /// Sends a raw completion request to `custom_url`, or to the Zed LLM
    /// `/predict_edits/raw` endpoint when no custom URL is given.
    ///
    /// With the `cli-support` feature and an `eval_cache`, responses are keyed by
    /// `eval_cache_kind` plus a hash of the URL and the pretty-printed request.
    /// A cache hit is returned without a network call and without usage
    /// information; a fresh response is written back to the cache.
    async fn send_raw_llm_request(
        request: RawCompletionRequest,
        client: Arc<Client>,
        custom_url: Option<Arc<Url>>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
        let url = if let Some(custom_url) = custom_url {
            custom_url.as_ref().clone()
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Serve from the eval cache when this exact request was seen before.
        #[cfg(feature = "cli-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await?;

        // Populate the eval cache on a miss.
        #[cfg(feature = "cli-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1995
1996 fn handle_api_response<T>(
1997 this: &WeakEntity<Self>,
1998 response: Result<(T, Option<EditPredictionUsage>)>,
1999 cx: &mut gpui::AsyncApp,
2000 ) -> Result<T> {
2001 match response {
2002 Ok((data, usage)) => {
2003 if let Some(usage) = usage {
2004 this.update(cx, |this, cx| {
2005 this.user_store.update(cx, |user_store, cx| {
2006 user_store.update_edit_prediction_usage(usage, cx);
2007 });
2008 })
2009 .ok();
2010 }
2011 Ok(data)
2012 }
2013 Err(err) => {
2014 if err.is::<ZedUpdateRequiredError>() {
2015 cx.update(|cx| {
2016 this.update(cx, |this, _cx| {
2017 this.update_required = true;
2018 })
2019 .ok();
2020
2021 let error_message: SharedString = err.to_string().into();
2022 show_app_notification(
2023 NotificationId::unique::<ZedUpdateRequiredError>(),
2024 cx,
2025 move |cx| {
2026 cx.new(|cx| {
2027 ErrorMessagePrompt::new(error_message.clone(), cx)
2028 .with_link_button("Update Zed", "https://zed.dev/releases")
2029 })
2030 },
2031 );
2032 });
2033 }
2034 Err(err)
2035 }
2036 }
2037 }
2038
    /// POSTs a JSON request built by `build` and deserializes the JSON response.
    ///
    /// The request carries the Zed version header and, when a token is
    /// available, a bearer token. With `require_auth`, failing to acquire a token
    /// is an error; otherwise the request is sent without authorization.
    ///
    /// If the server reports a minimum required version newer than
    /// `app_version`, a [`ZedUpdateRequiredError`] is returned. On an
    /// expired-token response the token is refreshed and the request retried
    /// once. Any other non-success status becomes an error that includes the
    /// response body.
    ///
    /// Returns the decoded body together with usage parsed from the response
    /// headers (`None` if parsing fails).
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        // `build` is invoked once per attempt, hence the `Fn` bound.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // Checked for every response, successful or not.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && token.is_some()
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh it and retry exactly once.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2114
2115 pub fn refresh_context(
2116 &mut self,
2117 project: &Entity<Project>,
2118 buffer: &Entity<language::Buffer>,
2119 cursor_position: language::Anchor,
2120 cx: &mut Context<Self>,
2121 ) {
2122 self.get_or_init_project(project, cx)
2123 .context
2124 .update(cx, |store, cx| {
2125 store.refresh(buffer.clone(), cursor_position, cx);
2126 });
2127 }
2128
2129 #[cfg(feature = "cli-support")]
2130 pub fn set_context_for_buffer(
2131 &mut self,
2132 project: &Entity<Project>,
2133 related_files: Vec<RelatedFile>,
2134 cx: &mut Context<Self>,
2135 ) {
2136 self.get_or_init_project(project, cx)
2137 .context
2138 .update(cx, |store, cx| {
2139 store.set_related_files(related_files, cx);
2140 });
2141 }
2142
2143 fn is_file_open_source(
2144 &self,
2145 project: &Entity<Project>,
2146 file: &Arc<dyn File>,
2147 cx: &App,
2148 ) -> bool {
2149 if !file.is_local() || file.is_private() {
2150 return false;
2151 }
2152 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2153 return false;
2154 };
2155 project_state
2156 .license_detection_watchers
2157 .get(&file.worktree_id(cx))
2158 .as_ref()
2159 .is_some_and(|watcher| watcher.is_project_open_source())
2160 }
2161
2162 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2163 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2164 }
2165
2166 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2167 if !self.data_collection_choice.is_enabled(cx) {
2168 return false;
2169 }
2170 events.iter().all(|event| {
2171 matches!(
2172 event.as_ref(),
2173 zeta_prompt::Event::BufferChange {
2174 in_open_source_repo: true,
2175 ..
2176 }
2177 )
2178 })
2179 }
2180
2181 fn load_data_collection_choice() -> DataCollectionChoice {
2182 let choice = KEY_VALUE_STORE
2183 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2184 .log_err()
2185 .flatten();
2186
2187 match choice.as_deref() {
2188 Some("true") => DataCollectionChoice::Enabled,
2189 Some("false") => DataCollectionChoice::Disabled,
2190 Some(_) => {
2191 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2192 DataCollectionChoice::NotAnswered
2193 }
2194 None => DataCollectionChoice::NotAnswered,
2195 }
2196 }
2197
2198 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2199 self.data_collection_choice = self.data_collection_choice.toggle();
2200 let new_choice = self.data_collection_choice;
2201 let is_enabled = new_choice.is_enabled(cx);
2202 db::write_and_log(cx, move || {
2203 KEY_VALUE_STORE.write_kvp(
2204 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2205 is_enabled.to_string(),
2206 )
2207 });
2208 }
2209
    /// Returns a double-ended iterator over the entries of `self.shown_predictions`.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2213
    /// Returns the number of entries in `self.shown_predictions`.
    ///
    /// NOTE(review): the name says "completions" while the field says "predictions".
    /// They refer to the same collection.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2217
    /// Returns whether `id` is in `rated_predictions`, the set that `rate_prediction` adds to.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2221
2222 pub fn rate_prediction(
2223 &mut self,
2224 prediction: &EditPrediction,
2225 rating: EditPredictionRating,
2226 feedback: String,
2227 cx: &mut Context<Self>,
2228 ) {
2229 self.rated_predictions.insert(prediction.id.clone());
2230 telemetry::event!(
2231 "Edit Prediction Rated",
2232 rating,
2233 inputs = prediction.inputs,
2234 output = prediction
2235 .edit_preview
2236 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2237 feedback
2238 );
2239 self.client.telemetry().flush_events().detach();
2240 cx.notify();
2241 }
2242}
2243
2244pub(crate) fn filter_redundant_excerpts(
2245 mut related_files: Vec<RelatedFile>,
2246 cursor_path: &Path,
2247 cursor_row_range: Range<u32>,
2248) -> Vec<RelatedFile> {
2249 for file in &mut related_files {
2250 if file.path.as_ref() == cursor_path {
2251 file.excerpts.retain(|excerpt| {
2252 excerpt.row_range.start < cursor_row_range.start
2253 || excerpt.row_range.end > cursor_row_range.end
2254 });
2255 }
2256 }
2257 related_files.retain(|file| !file.excerpts.is_empty());
2258 related_files
2259}
2260
/// Error telling the user they must update Zed to at least `minimum_version` to keep
/// using edit predictions.
///
/// NOTE(review): presumably built from the server's minimum-required-version response
/// header. Confirm at the construction site.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version the user is told to update to.
    minimum_version: Version,
}
2268
/// Key for an [`EvalCache`] entry: the kind of entry paired with a `u64`.
///
/// NOTE(review): the `u64` looks like a hash of the cached request's input. Confirm at
/// call sites.
#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2271
/// Category of an [`EvalCache`] entry; the first component of an [`EvalCacheKey`].
#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    /// Cached context retrieval.
    Context,
    /// Cached search.
    Search,
    /// Cached prediction.
    Prediction,
}
2279
/// Renders each kind as its lowercase name (`"context"`, `"search"`, `"prediction"`).
#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Context => "context",
            Self::Search => "search",
            Self::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2290
/// Storage for evaluation results, available with the `cli-support` feature.
///
/// Implementors must be `Send + Sync`.
#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the value stored under `key`, or `None` if there is none.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`. The originating `input` is passed as well.
    /// NOTE(review): presumably kept for inspection/debugging; confirm in implementors.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2296
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No valid choice has been stored yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2303
2304impl DataCollectionChoice {
2305 pub fn is_enabled(self, cx: &App) -> bool {
2306 if cx.is_staff() {
2307 return true;
2308 }
2309 match self {
2310 Self::Enabled => true,
2311 Self::NotAnswered | Self::Disabled => false,
2312 }
2313 }
2314
2315 #[must_use]
2316 pub fn toggle(&self) -> DataCollectionChoice {
2317 match self {
2318 Self::Enabled => Self::Disabled,
2319 Self::Disabled => Self::Enabled,
2320 Self::NotAnswered => Self::Enabled,
2321 }
2322 }
2323}
2324
2325impl From<bool> for DataCollectionChoice {
2326 fn from(value: bool) -> Self {
2327 match value {
2328 true => DataCollectionChoice::Enabled,
2329 false => DataCollectionChoice::Disabled,
2330 }
2331 }
2332}
2333
2334struct ZedPredictUpsell;
2335
2336impl Dismissable for ZedPredictUpsell {
2337 const KEY: &'static str = "dismissed-edit-predict-upsell";
2338
2339 fn dismissed() -> bool {
2340 // To make this backwards compatible with older versions of Zed, we
2341 // check if the user has seen the previous Edit Prediction Onboarding
2342 // before, by checking the data collection choice which was written to
2343 // the database once the user clicked on "Accept and Enable"
2344 if KEY_VALUE_STORE
2345 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2346 .log_err()
2347 .is_some_and(|s| s.is_some())
2348 {
2349 return true;
2350 }
2351
2352 KEY_VALUE_STORE
2353 .read_kvp(Self::KEY)
2354 .log_err()
2355 .is_some_and(|s| s.is_some())
2356 }
2357}
2358
2359pub fn should_show_upsell_modal() -> bool {
2360 !ZedPredictUpsell::dismissed()
2361}
2362
2363pub fn init(cx: &mut App) {
2364 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2365 workspace.register_action(
2366 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2367 ZedPredictModal::toggle(
2368 workspace,
2369 workspace.user_store().clone(),
2370 workspace.client().clone(),
2371 window,
2372 cx,
2373 )
2374 },
2375 );
2376
2377 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2378 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2379 settings
2380 .project
2381 .all_languages
2382 .features
2383 .get_or_insert_default()
2384 .edit_prediction_provider = Some(EditPredictionProvider::None)
2385 });
2386 });
2387 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2388 EditPredictionStore::try_global(cx).and_then(|store| {
2389 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2390 })
2391 }
2392
2393 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2394 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2395 copilot_ui::initiate_sign_in(copilot, window, cx);
2396 }
2397 });
2398 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2399 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2400 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2401 }
2402 });
2403 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2404 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2405 copilot_ui::initiate_sign_out(copilot, window, cx);
2406 }
2407 });
2408 })
2409 .detach();
2410}