1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use text::Edit;
40use workspace::Workspace;
41use zeta_prompt::ZetaPromptInput;
42use zeta_prompt::ZetaVersion;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
50use std::{env, mem};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59pub mod ollama;
60mod onboarding_modal;
61pub mod open_ai_response;
62mod prediction;
63pub mod sweep_ai;
64
65pub mod udiff;
66
67mod capture_example;
68mod zed_edit_prediction_delegate;
69pub mod zeta1;
70pub mod zeta2;
71
72#[cfg(test)]
73mod edit_prediction_tests;
74
75use crate::capture_example::{
76 should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
77};
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::ollama::Ollama;
81use crate::onboarding_modal::ZedPredictModal;
82pub use crate::prediction::EditPrediction;
83pub use crate::prediction::EditPredictionId;
84use crate::prediction::EditPredictionResult;
85pub use crate::sweep_ai::SweepAi;
86pub use capture_example::capture_example;
87pub use language_model::ApiKeyState;
88pub use telemetry_events::EditPredictionRating;
89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
90
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits whose row ranges are at most this many rows apart may be coalesced
/// into the same pending event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// A gap between edits of at least this long counts as an editing pause.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_millis(1_000);
/// Key under which the data collection choice is stored (read/written outside this chunk).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for rejection reporting (used outside this chunk).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_millis(15_000);

/// Optional model id taken from the `ZED_ZETA_MODEL` environment variable,
/// read lazily on first access; `None` when the variable is unset or not UTF-8.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
110
/// Feature flag named `zeta2`. In this file it gates prediction refreshes
/// triggered by project diagnostics updates (see `handle_project_event`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Reports the flag as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}

/// Feature flag named `edit-prediction-example-capture`; presumably gates
/// capturing edit prediction examples (it is not checked in this chunk).
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    // Reports the flag as enabled for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}
130
/// Global wrapper holding the app-wide [`EditPredictionStore`] entity
/// (installed by `EditPredictionStore::global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}

/// Central edit prediction state: per-project tracking, the active prediction
/// model and provider clients, and reporting of rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token, shared with the rejection-reporting background task and
    /// refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// Initialized to `false` in `new`; presumably set when the server requires a
    /// newer client version — TODO confirm (not set in this chunk).
    update_required: bool,
    /// Model used for predictions; `new` starts with Zeta2.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    /// Sender end of the channel drained by the background task spawned in `new`
    /// (`handle_rejected_predictions`).
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    /// Optional override of the predict-edits endpoint, parsed from the
    /// `ZED_PREDICT_EDITS_URL` environment variable in `new`.
    custom_predict_edits_url: Option<Arc<Url>>,
}
153
/// The backend used to produce edit predictions.
///
/// `Default` yields `Zeta1`, although `EditPredictionStore::new` initializes the
/// store with `Zeta2` using the default `ZetaVersion`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    /// Zeta2, parameterized by a `ZetaVersion`.
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
    Ollama,
}
165
/// Inputs handed to a prediction model for one request: the target buffer and
/// position, edit history, related context, and request metadata.
/// (Constructed outside this chunk.)
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Position in `buffer` the prediction is requested for.
    position: Anchor,
    /// Recent edit events, presumably oldest first — TODO confirm at the construction site.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    /// Recently active project paths (kept most-recent-first in `ProjectState`).
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    /// Row range in which diagnostics are considered (semantics defined by callers).
    diagnostic_search_range: Range<Point>,
    /// Optional sink for debug events.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
180
/// Events sent to the channel returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Sent when a project's related-excerpt store starts refreshing its context.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Currently always sent as an empty string by `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Sent when a context refresh finishes; `metadata` holds display-ready
/// label/value pairs (cache hits, max and mean LSP definition latency).
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Sent when an edit prediction request starts (emitted outside this chunk).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Sent when an edit prediction request finishes (emitted outside this chunk).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
216
/// Maximum number of entries retained in `ProjectState::user_actions`; the oldest
/// entry is evicted when a new one is recorded at capacity.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer in which the action happened.
    pub buffer_id: EntityId,
    /// Row (`Point::row`) corresponding to `offset`.
    pub line_number: u32,
    /// Buffer offset of the action; for edits, the end of the last edit in the new text.
    pub offset: usize,
    /// Wall-clock milliseconds since the Unix epoch (0 if the clock is before the epoch).
    pub timestamp_epoch_ms: u64,
}

/// Coarse classification of a user action.
///
/// For edits, `report_changes_for_buffer` classifies by the total number of
/// inserted/deleted bytes relative to the number of edits: exactly one byte per
/// edit maps to `InsertChar`/`DeleteChar`, anything larger to the `*Selection`
/// variants. `CursorMovement` is not produced in this chunk.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}

/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the edits described by `event`.
    pub old_snapshot: TextBufferSnapshot,
}
243
/// Per-project edit prediction state, stored in `EditPredictionStore::projects`
/// keyed by the project's entity id.
struct ProjectState {
    /// Finalized edit events, oldest first.
    events: VecDeque<StoredEvent>,
    /// The most recent, still-growing edit; finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers observed for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Id for the next pending prediction (presumably incremented on each request — TODO confirm).
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; at most two at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug event sink installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Buffer and time of the most recent prediction refresh (maintained outside this chunk).
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that have been cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context for this project.
    context: Entity<RelatedExcerptStore>,
    /// Per-worktree license watchers, used to compute `in_open_source_repo`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded history of recent user actions (capacity `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    /// Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot instance started by `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
261
262impl ProjectState {
263 fn record_user_action(&mut self, action: UserActionRecord) {
264 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
265 self.user_actions.pop_front();
266 }
267 self.user_actions.push_back(action);
268 }
269
270 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
271 self.events
272 .iter()
273 .cloned()
274 .chain(
275 self.last_event
276 .as_ref()
277 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
278 )
279 .collect()
280 }
281
282 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
283 self.events
284 .iter()
285 .cloned()
286 .chain(self.last_event.as_ref().iter().flat_map(|event| {
287 let (one, two) = event.split_by_pause();
288 let one = one.finalize(&self.license_detection_watchers, cx);
289 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
290 one.into_iter().chain(two)
291 }))
292 .collect()
293 }
294
295 fn cancel_pending_prediction(
296 &mut self,
297 pending_prediction: PendingPrediction,
298 cx: &mut Context<EditPredictionStore>,
299 ) {
300 self.cancelled_predictions.insert(pending_prediction.id);
301
302 if pending_prediction.drop_on_cancel {
303 drop(pending_prediction.task);
304 } else {
305 cx.spawn(async move |this, cx| {
306 let Some(prediction_id) = pending_prediction.task.await else {
307 return;
308 };
309
310 this.update(cx, |this, _cx| {
311 this.reject_prediction(
312 prediction_id,
313 EditPredictionRejectReason::Canceled,
314 false,
315 );
316 })
317 .ok();
318 })
319 .detach()
320 }
321 }
322
323 fn active_buffer(
324 &self,
325 project: &Entity<Project>,
326 cx: &App,
327 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
328 let project = project.read(cx);
329 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
330 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
331 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
332 Some((active_buffer, registered_buffer.last_position))
333 }
334}
335
/// The prediction currently held for a project (`ProjectState::current_prediction`)
/// plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request; consulted by `should_replace_prediction`.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Presumably set once the prediction has been displayed — TODO confirm (set outside this chunk).
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
343
344impl CurrentEditPrediction {
345 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
346 let Some(new_edits) = self
347 .prediction
348 .interpolate(&self.prediction.buffer.read(cx))
349 else {
350 return false;
351 };
352
353 if self.prediction.buffer != old_prediction.prediction.buffer {
354 return true;
355 }
356
357 let Some(old_edits) = old_prediction
358 .prediction
359 .interpolate(&old_prediction.prediction.buffer.read(cx))
360 else {
361 return true;
362 };
363
364 let requested_by_buffer_id = self.requested_by.buffer_id();
365
366 // This reduces the occurrence of UI thrash from replacing edits
367 //
368 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
369 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
370 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
371 && old_edits.len() == 1
372 && new_edits.len() == 1
373 {
374 let (old_range, old_text) = &old_edits[0];
375 let (new_range, new_text) = &new_edits[0];
376 new_range == old_range && new_text.starts_with(old_text.as_ref())
377 } else {
378 true
379 }
380 }
381}
382
383#[derive(Debug, Clone)]
384enum PredictionRequestedBy {
385 DiagnosticsUpdate,
386 Buffer(EntityId),
387}
388
389impl PredictionRequestedBy {
390 pub fn buffer_id(&self) -> Option<EntityId> {
391 match self {
392 PredictionRequestedBy::DiagnosticsUpdate => None,
393 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
394 }
395 }
396}
397
/// An in-flight prediction request held in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    /// Id recorded in `ProjectState::cancelled_predictions` when cancelled.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` when there is nothing
    /// to report (in which case cancellation reports no rejection).
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
406
407/// A prediction from the perspective of a buffer.
408#[derive(Debug)]
409enum BufferEditPrediction<'a> {
410 Local { prediction: &'a EditPrediction },
411 Jump { prediction: &'a EditPrediction },
412}
413
414#[cfg(test)]
415impl std::ops::Deref for BufferEditPrediction<'_> {
416 type Target = EditPrediction;
417
418 fn deref(&self) -> &Self::Target {
419 match self {
420 BufferEditPrediction::Local { prediction } => prediction,
421 BufferEditPrediction::Jump { prediction } => prediction,
422 }
423 }
424}
425
/// A buffer observed for edits by a project's prediction state.
struct RegisteredBuffer {
    /// The buffer's file as of the last reported change.
    file: Option<Arc<dyn File>>,
    /// The buffer's text snapshot as of the last reported change.
    snapshot: TextBufferSnapshot,
    /// Last recorded position in the buffer (set outside this chunk; presumably the cursor).
    last_position: Option<Anchor>,
    /// Edit subscription and buffer-release observer.
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent, still-coalescing edit to a buffer.
#[derive(Clone)]
struct LastEvent {
    /// Snapshot before the first coalesced edit.
    old_snapshot: TextBufferSnapshot,
    /// Snapshot after the latest coalesced edit.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Anchor range spanning the coalesced edits, used to decide whether a new
    /// edit is close enough to be merged.
    edit_range: Option<Range<Anchor>>,
    /// Boundary snapshot used by `split_by_pause`.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
443
444impl LastEvent {
445 pub fn finalize(
446 &self,
447 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
448 cx: &App,
449 ) -> Option<StoredEvent> {
450 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
451 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
452
453 let in_open_source_repo =
454 [self.new_file.as_ref(), self.old_file.as_ref()]
455 .iter()
456 .all(|file| {
457 file.is_some_and(|file| {
458 license_detection_watchers
459 .get(&file.worktree_id(cx))
460 .is_some_and(|watcher| watcher.is_project_open_source())
461 })
462 });
463
464 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
465
466 if path == old_path && diff.is_empty() {
467 None
468 } else {
469 Some(StoredEvent {
470 event: Arc::new(zeta_prompt::Event::BufferChange {
471 old_path,
472 path,
473 diff,
474 in_open_source_repo,
475 // TODO: Actually detect if this edit was predicted or not
476 predicted: false,
477 }),
478 old_snapshot: self.old_snapshot.clone(),
479 })
480 }
481 }
482
483 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
484 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
485 return (self.clone(), None);
486 };
487
488 let before = LastEvent {
489 old_snapshot: self.old_snapshot.clone(),
490 new_snapshot: boundary_snapshot.clone(),
491 old_file: self.old_file.clone(),
492 new_file: self.new_file.clone(),
493 edit_range: None,
494 snapshot_after_last_editing_pause: None,
495 last_edit_time: self.last_edit_time,
496 };
497
498 let after = LastEvent {
499 old_snapshot: boundary_snapshot.clone(),
500 new_snapshot: self.new_snapshot.clone(),
501 old_file: self.old_file.clone(),
502 new_file: self.new_file.clone(),
503 edit_range: None,
504 snapshot_after_last_editing_pause: None,
505 last_edit_time: self.last_edit_time,
506 };
507
508 (before, Some(after))
509 }
510}
511
/// Builds a unified diff describing the edits between `old_snapshot` and
/// `new_snapshot`.
///
/// The diffed text is a single contiguous row region per snapshot, spanning
/// from the first edit's start row to the last edit's end row, padded with
/// context rows. Returns `None` if `new_snapshot` has no edits since
/// `old_snapshot`'s version.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<String> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // No edits since the old version: nothing to diff.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Leading context: up to CONTEXT_LINES rows above the first edited row.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // Trailing context, clamped to the last row of each snapshot.
    // NOTE(review): this row is `end.row + 1 + CONTEXT_LINES` and the region below
    // runs through it, so trailing context can be CONTEXT_LINES + 1 rows — one more
    // than the leading context. Confirm the asymmetry is intended.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Region offsets: from the start of the first context row to the start of the
    // row after the last context row (or end of buffer).
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // The start rows are passed as offsets, presumably so hunk headers refer to
    // buffer rows rather than region-relative rows.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some(diff)
}
557
558fn buffer_path_with_id_fallback(
559 file: Option<&Arc<dyn File>>,
560 snapshot: &TextBufferSnapshot,
561 cx: &App,
562) -> Arc<Path> {
563 if let Some(file) = file {
564 file.full_path(cx).into()
565 } else {
566 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
567 }
568}
569
570impl EditPredictionStore {
571 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
572 cx.try_global::<EditPredictionStoreGlobal>()
573 .map(|global| global.0.clone())
574 }
575
576 pub fn global(
577 client: &Arc<Client>,
578 user_store: &Entity<UserStore>,
579 cx: &mut App,
580 ) -> Entity<Self> {
581 cx.try_global::<EditPredictionStoreGlobal>()
582 .map(|global| global.0.clone())
583 .unwrap_or_else(|| {
584 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
585 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
586 ep_store
587 })
588 }
589
590 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
591 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
592 let data_collection_choice = Self::load_data_collection_choice();
593
594 let llm_token = LlmApiToken::default();
595
596 let (reject_tx, reject_rx) = mpsc::unbounded();
597 cx.background_spawn({
598 let client = client.clone();
599 let llm_token = llm_token.clone();
600 let app_version = AppVersion::global(cx);
601 let background_executor = cx.background_executor().clone();
602 async move {
603 Self::handle_rejected_predictions(
604 reject_rx,
605 client,
606 llm_token,
607 app_version,
608 background_executor,
609 )
610 .await
611 }
612 })
613 .detach();
614
615 let this = Self {
616 projects: HashMap::default(),
617 client,
618 user_store,
619 llm_token,
620 _llm_token_subscription: cx.subscribe(
621 &refresh_llm_token_listener,
622 |this, _listener, _event, cx| {
623 let client = this.client.clone();
624 let llm_token = this.llm_token.clone();
625 cx.spawn(async move |_this, _cx| {
626 llm_token.refresh(&client).await?;
627 anyhow::Ok(())
628 })
629 .detach_and_log_err(cx);
630 },
631 ),
632 update_required: false,
633 edit_prediction_model: EditPredictionModel::Zeta2 {
634 version: Default::default(),
635 },
636 sweep_ai: SweepAi::new(cx),
637 mercury: Mercury::new(cx),
638 ollama: Ollama::new(),
639
640 data_collection_choice,
641 reject_predictions_tx: reject_tx,
642 rated_predictions: Default::default(),
643 shown_predictions: Default::default(),
644 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
645 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
646 Err(_) => None,
647 },
648 };
649
650 this
651 }
652
653 #[cfg(test)]
654 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
655 self.custom_predict_edits_url = Some(url.into());
656 }
657
658 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
659 self.edit_prediction_model = model;
660 }
661
662 pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
663 use ui::IconName;
664 match self.edit_prediction_model {
665 EditPredictionModel::Sweep => {
666 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
667 .with_disabled(IconName::SweepAiDisabled)
668 .with_up(IconName::SweepAiUp)
669 .with_down(IconName::SweepAiDown)
670 .with_error(IconName::SweepAiError)
671 }
672 EditPredictionModel::Mercury => {
673 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
674 }
675 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
676 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
677 .with_disabled(IconName::ZedPredictDisabled)
678 .with_up(IconName::ZedPredictUp)
679 .with_down(IconName::ZedPredictDown)
680 .with_error(IconName::ZedPredictError)
681 }
682 EditPredictionModel::Ollama => {
683 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
684 }
685 }
686 }
687
688 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
689 self.sweep_ai.api_token.read(cx).has_key()
690 }
691
692 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
693 self.mercury.api_token.read(cx).has_key()
694 }
695
696 pub fn clear_history(&mut self) {
697 for project_state in self.projects.values_mut() {
698 project_state.events.clear();
699 project_state.last_event.take();
700 }
701 }
702
703 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
704 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
705 project_state.events.clear();
706 project_state.last_event.take();
707 }
708 }
709
710 pub fn edit_history_for_project(
711 &self,
712 project: &Entity<Project>,
713 cx: &App,
714 ) -> Vec<StoredEvent> {
715 self.projects
716 .get(&project.entity_id())
717 .map(|project_state| project_state.events(cx))
718 .unwrap_or_default()
719 }
720
721 pub fn edit_history_for_project_with_pause_split_last_event(
722 &self,
723 project: &Entity<Project>,
724 cx: &App,
725 ) -> Vec<StoredEvent> {
726 self.projects
727 .get(&project.entity_id())
728 .map(|project_state| project_state.events_split_by_pause(cx))
729 .unwrap_or_default()
730 }
731
732 pub fn context_for_project<'a>(
733 &'a self,
734 project: &Entity<Project>,
735 cx: &'a mut App,
736 ) -> Vec<RelatedFile> {
737 self.projects
738 .get(&project.entity_id())
739 .map(|project| {
740 project
741 .context
742 .update(cx, |context, cx| context.related_files(cx))
743 })
744 .unwrap_or_default()
745 }
746
747 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
748 self.projects
749 .get(&project.entity_id())
750 .and_then(|project| project.copilot.clone())
751 }
752
753 pub fn start_copilot_for_project(
754 &mut self,
755 project: &Entity<Project>,
756 cx: &mut Context<Self>,
757 ) -> Option<Entity<Copilot>> {
758 if DisableAiSettings::get(None, cx).disable_ai {
759 return None;
760 }
761 let state = self.get_or_init_project(project, cx);
762
763 if state.copilot.is_some() {
764 return state.copilot.clone();
765 }
766 let _project = project.clone();
767 let project = project.read(cx);
768
769 let node = project.node_runtime().cloned();
770 if let Some(node) = node {
771 let next_id = project.languages().next_language_server_id();
772 let fs = project.fs().clone();
773
774 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
775 state.copilot = Some(copilot.clone());
776 Some(copilot)
777 } else {
778 None
779 }
780 }
781
782 pub fn context_for_project_with_buffers<'a>(
783 &'a self,
784 project: &Entity<Project>,
785 cx: &'a mut App,
786 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
787 self.projects
788 .get(&project.entity_id())
789 .map(|project| {
790 project
791 .context
792 .update(cx, |context, cx| context.related_files_with_buffers(cx))
793 })
794 .unwrap_or_default()
795 }
796
797 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
798 if matches!(
799 self.edit_prediction_model,
800 EditPredictionModel::Zeta2 { .. }
801 ) {
802 self.user_store.read(cx).edit_prediction_usage()
803 } else {
804 None
805 }
806 }
807
808 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
809 self.get_or_init_project(project, cx);
810 }
811
812 pub fn register_buffer(
813 &mut self,
814 buffer: &Entity<Buffer>,
815 project: &Entity<Project>,
816 cx: &mut Context<Self>,
817 ) {
818 let project_state = self.get_or_init_project(project, cx);
819 Self::register_buffer_impl(project_state, buffer, project, cx);
820 }
821
822 fn get_or_init_project(
823 &mut self,
824 project: &Entity<Project>,
825 cx: &mut Context<Self>,
826 ) -> &mut ProjectState {
827 let entity_id = project.entity_id();
828 self.projects
829 .entry(entity_id)
830 .or_insert_with(|| ProjectState {
831 context: {
832 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
833 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
834 this.handle_excerpt_store_event(entity_id, event);
835 })
836 .detach();
837 related_excerpt_store
838 },
839 events: VecDeque::new(),
840 last_event: None,
841 recent_paths: VecDeque::new(),
842 debug_tx: None,
843 registered_buffers: HashMap::default(),
844 current_prediction: None,
845 cancelled_predictions: HashSet::default(),
846 pending_predictions: ArrayVec::new(),
847 next_pending_prediction_id: 0,
848 last_prediction_refresh: None,
849 license_detection_watchers: HashMap::default(),
850 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
851 _subscriptions: [
852 cx.subscribe(&project, Self::handle_project_event),
853 cx.observe_release(&project, move |this, _, cx| {
854 this.projects.remove(&entity_id);
855 cx.notify();
856 }),
857 ],
858 copilot: None,
859 })
860 }
861
862 pub fn remove_project(&mut self, project: &Entity<Project>) {
863 self.projects.remove(&project.entity_id());
864 }
865
866 fn handle_excerpt_store_event(
867 &mut self,
868 project_entity_id: EntityId,
869 event: &RelatedExcerptStoreEvent,
870 ) {
871 if let Some(project_state) = self.projects.get(&project_entity_id) {
872 if let Some(debug_tx) = project_state.debug_tx.clone() {
873 match event {
874 RelatedExcerptStoreEvent::StartedRefresh => {
875 debug_tx
876 .unbounded_send(DebugEvent::ContextRetrievalStarted(
877 ContextRetrievalStartedDebugEvent {
878 project_entity_id: project_entity_id,
879 timestamp: Instant::now(),
880 search_prompt: String::new(),
881 },
882 ))
883 .ok();
884 }
885 RelatedExcerptStoreEvent::FinishedRefresh {
886 cache_hit_count,
887 cache_miss_count,
888 mean_definition_latency,
889 max_definition_latency,
890 } => {
891 debug_tx
892 .unbounded_send(DebugEvent::ContextRetrievalFinished(
893 ContextRetrievalFinishedDebugEvent {
894 project_entity_id: project_entity_id,
895 timestamp: Instant::now(),
896 metadata: vec![
897 (
898 "Cache Hits",
899 format!(
900 "{}/{}",
901 cache_hit_count,
902 cache_hit_count + cache_miss_count
903 )
904 .into(),
905 ),
906 (
907 "Max LSP Time",
908 format!("{} ms", max_definition_latency.as_millis())
909 .into(),
910 ),
911 (
912 "Mean LSP Time",
913 format!("{} ms", mean_definition_latency.as_millis())
914 .into(),
915 ),
916 ],
917 },
918 ))
919 .ok();
920 }
921 }
922 }
923 }
924 }
925
926 pub fn debug_info(
927 &mut self,
928 project: &Entity<Project>,
929 cx: &mut Context<Self>,
930 ) -> mpsc::UnboundedReceiver<DebugEvent> {
931 let project_state = self.get_or_init_project(project, cx);
932 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
933 project_state.debug_tx = Some(debug_watch_tx);
934 debug_watch_rx
935 }
936
937 fn handle_project_event(
938 &mut self,
939 project: Entity<Project>,
940 event: &project::Event,
941 cx: &mut Context<Self>,
942 ) {
943 // TODO [zeta2] init with recent paths
944 match event {
945 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
946 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
947 return;
948 };
949 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
950 if let Some(path) = path {
951 if let Some(ix) = project_state
952 .recent_paths
953 .iter()
954 .position(|probe| probe == &path)
955 {
956 project_state.recent_paths.remove(ix);
957 }
958 project_state.recent_paths.push_front(path);
959 }
960 }
961 project::Event::DiagnosticsUpdated { .. } => {
962 if cx.has_flag::<Zeta2FeatureFlag>() {
963 self.refresh_prediction_from_diagnostics(project, cx);
964 }
965 }
966 _ => (),
967 }
968 }
969
    /// Registers `buffer` with `project_state` if it is not registered yet, and
    /// returns its `RegisteredBuffer` entry either way.
    ///
    /// If the buffer has a file in one of the project's worktrees and that
    /// worktree has no license watcher yet, one is created; it is removed again
    /// when the worktree entity is released.
    ///
    /// A newly registered buffer records its current snapshot and file. It is
    /// subscribed so that `Edited` events are reported via
    /// `report_changes_for_buffer`, and observed so that the registration is
    /// removed when the buffer is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop this worktree's watcher once the worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Captures only a weak project handle (presumably so the
                            // subscription does not keep the project alive); edits are
                            // ignored once the project can no longer be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister the buffer when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1035
    /// Diffs `buffer` against the snapshot last registered for it and records the change in
    /// `project`'s edit history.
    ///
    /// Returns early when the buffer version is unchanged. Otherwise it:
    /// - swaps the registered file/snapshot for the current ones;
    /// - classifies the edits since the old snapshot into a [`UserActionType`] and records a
    ///   [`UserActionRecord`] at the end offset of the last edit;
    /// - coalesces the change into `last_event` when it directly follows that event's latest
    ///   snapshot of the same buffer and lies within `CHANGE_GROUPING_LINE_SPAN` rows of its
    ///   edit range;
    /// - otherwise finalizes the previous `last_event` into `events` and starts a new one.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        // Presumably registers the buffer on first sight; either way yields its entry.
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing to report if the buffer has not changed since it was last recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Store the current file/snapshot, keeping the previous ones to diff against.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        // Spans from the start of the first edit to the end of the last one (assumes edits
        // are yielded in buffer order).
        let mut edit_range: Option<Range<Anchor>> = None;
        // End offset, in the new snapshot, of the last edit.
        let mut last_offset: Option<usize> = None;

        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        if num_edits > 0 {
            // Heuristic: exactly one unit inserted/deleted per edit counts as typing/deleting a
            // single character (possibly at several cursors); anything larger as a selection.
            // Edits that both delete and insert are classified by what they insert.
            // NOTE(review): lengths come from `usize` offsets (presumably bytes), so typing one
            // multi-byte character is classified as `InsertSelection` — confirm that's intended.
            let action_type = match (total_deleted, total_inserted, num_edits) {
                (0, ins, n) if ins == n => UserActionType::InsertChar,
                (0, _, _) => UserActionType::InsertSelection,
                (del, 0, n) if del == n => UserActionType::DeleteChar,
                (_, 0, _) => UserActionType::DeleteSelection,
                (_, ins, n) if ins == n => UserActionType::InsertChar,
                (_, _, _) => UserActionType::InsertSelection,
            };

            if let Some(offset) = last_offset {
                let point = new_snapshot.offset_to_point(offset);
                // Milliseconds since the Unix epoch; falls back to 0 if the clock is earlier.
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                project_state.record_user_action(UserActionRecord {
                    action_type,
                    buffer_id: buffer.entity_id(),
                    line_number: point.row,
                    offset,
                    timestamp_epoch_ms,
                });
            }
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The change can only extend the last event if it starts exactly from that event's
            // latest snapshot of the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // ...and its edit range is within CHANGE_GROUPING_LINE_SPAN rows of the last
            // event's range (overlapping or adjacent ranges always qualify).
            let should_coalesce = is_next_snapshot_of_same_buffer
                && edit_range
                    .as_ref()
                    .zip(last_event.edit_range.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        if a.start > b.end {
                            a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else if b.start > a.end {
                            b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                        } else {
                            true
                        }
                    });

            if should_coalesce {
                // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit, remember
                // the snapshot the user had reached before pausing.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = edit_range;
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: move the previous event into the bounded history before starting a
        // new one. `finalize` can return `None`, in which case nothing is recorded.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest entry when full.
                // NOTE(review): this caps `events` at EVENT_COUNT_MAX - 1, presumably leaving
                // room for the in-progress `last_event`; confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1159
1160 fn prediction_at(
1161 &mut self,
1162 buffer: &Entity<Buffer>,
1163 position: Option<language::Anchor>,
1164 project: &Entity<Project>,
1165 cx: &App,
1166 ) -> Option<BufferEditPrediction<'_>> {
1167 let project_state = self.projects.get_mut(&project.entity_id())?;
1168 if let Some(position) = position
1169 && let Some(buffer) = project_state
1170 .registered_buffers
1171 .get_mut(&buffer.entity_id())
1172 {
1173 buffer.last_position = Some(position);
1174 }
1175
1176 let CurrentEditPrediction {
1177 requested_by,
1178 prediction,
1179 ..
1180 } = project_state.current_prediction.as_ref()?;
1181
1182 if prediction.targets_buffer(buffer.read(cx)) {
1183 Some(BufferEditPrediction::Local { prediction })
1184 } else {
1185 let show_jump = match requested_by {
1186 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1187 requested_by_buffer_id == &buffer.entity_id()
1188 }
1189 PredictionRequestedBy::DiagnosticsUpdate => true,
1190 };
1191
1192 if show_jump {
1193 Some(BufferEditPrediction::Jump { prediction })
1194 } else {
1195 None
1196 }
1197 }
1198 }
1199
1200 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1201 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1202 return;
1203 };
1204
1205 let Some(current_prediction) = project_state.current_prediction.take() else {
1206 return;
1207 };
1208
1209 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1210 project_state.cancel_pending_prediction(pending_prediction, cx);
1211 }
1212
1213 match self.edit_prediction_model {
1214 EditPredictionModel::Sweep => {
1215 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1216 }
1217 EditPredictionModel::Mercury | EditPredictionModel::Ollama => {}
1218 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1219 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1220 }
1221 }
1222 }
1223
    /// Long-running loop that batches prediction rejections received on `rx` and reports them
    /// to the `/predict_edits/reject` LLM endpoint.
    ///
    /// While fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` rejections are
    /// batched, it keeps draining already-queued rejections and otherwise waits up to
    /// `REJECT_REQUEST_DEBOUNCE` before sending. Each request carries at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most recently batched rejections.
    /// They are removed from the batch only when the request succeeds. Failed rejections
    /// therefore stay queued and are resent when the next rejection arrives. The loop ends
    /// when the channel is closed.
    ///
    /// NOTE(review): while requests keep failing, `batched` grows without bound, and the
    /// `.unwrap()` on the URL below panics if the LLM URL cannot be built — consider bounding
    /// the backlog and logging instead of panicking.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Small batch: prefer an already-queued rejection (loop back and take it),
            // otherwise flush once the debounce timer fires. If the channel has closed,
            // `peek` yields `None` and we fall through to flush what we have.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` rejections; older leftovers wait for a later round.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only forget the rejections that were actually delivered.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1282
1283 fn reject_current_prediction(
1284 &mut self,
1285 reason: EditPredictionRejectReason,
1286 project: &Entity<Project>,
1287 ) {
1288 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1289 project_state.pending_predictions.clear();
1290 if let Some(prediction) = project_state.current_prediction.take() {
1291 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
1292 }
1293 };
1294 }
1295
1296 fn did_show_current_prediction(
1297 &mut self,
1298 project: &Entity<Project>,
1299 display_type: edit_prediction_types::SuggestionDisplayType,
1300 cx: &mut Context<Self>,
1301 ) {
1302 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1303 return;
1304 };
1305
1306 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1307 return;
1308 };
1309
1310 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1311 let previous_shown_with = current_prediction.shown_with;
1312
1313 if previous_shown_with.is_none() || !is_jump {
1314 current_prediction.shown_with = Some(display_type);
1315 }
1316
1317 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1318
1319 if is_first_non_jump_show {
1320 current_prediction.was_shown = true;
1321 }
1322
1323 let display_type_changed = previous_shown_with != Some(display_type);
1324
1325 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1326 sweep_ai::edit_prediction_shown(
1327 &self.sweep_ai,
1328 self.client.clone(),
1329 ¤t_prediction.prediction,
1330 display_type,
1331 cx,
1332 );
1333 }
1334
1335 if is_first_non_jump_show {
1336 self.shown_predictions
1337 .push_front(current_prediction.prediction.clone());
1338 if self.shown_predictions.len() > 50 {
1339 let completion = self.shown_predictions.pop_back().unwrap();
1340 self.rated_predictions.remove(&completion.id);
1341 }
1342 }
1343 }
1344
1345 fn reject_prediction(
1346 &mut self,
1347 prediction_id: EditPredictionId,
1348 reason: EditPredictionRejectReason,
1349 was_shown: bool,
1350 ) {
1351 match self.edit_prediction_model {
1352 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1353 if self.custom_predict_edits_url.is_some() {
1354 return;
1355 }
1356 }
1357 EditPredictionModel::Sweep
1358 | EditPredictionModel::Mercury
1359 | EditPredictionModel::Ollama => return,
1360 }
1361
1362 self.reject_predictions_tx
1363 .unbounded_send(EditPredictionRejection {
1364 request_id: prediction_id.to_string(),
1365 reason,
1366 was_shown,
1367 })
1368 .log_err();
1369 }
1370
1371 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1372 self.projects
1373 .get(&project.entity_id())
1374 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1375 }
1376
1377 pub fn refresh_prediction_from_buffer(
1378 &mut self,
1379 project: Entity<Project>,
1380 buffer: Entity<Buffer>,
1381 position: language::Anchor,
1382 cx: &mut Context<Self>,
1383 ) {
1384 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1385 let Some(request_task) = this
1386 .update(cx, |this, cx| {
1387 this.request_prediction(
1388 &project,
1389 &buffer,
1390 position,
1391 PredictEditsRequestTrigger::Other,
1392 cx,
1393 )
1394 })
1395 .log_err()
1396 else {
1397 return Task::ready(anyhow::Ok(None));
1398 };
1399
1400 cx.spawn(async move |_cx| {
1401 request_task.await.map(|prediction_result| {
1402 prediction_result.map(|prediction_result| {
1403 (
1404 prediction_result,
1405 PredictionRequestedBy::Buffer(buffer.entity_id()),
1406 )
1407 })
1408 })
1409 })
1410 })
1411 }
1412
    /// Queues a prediction request at the nearest diagnostic location, in response to a
    /// diagnostics update.
    ///
    /// Skipped when the project is unknown or already has a current prediction (buffer-driven
    /// predictions are preferred). The throttled refresh is keyed on the project entity. It
    /// does the following:
    /// 1. Resolves the project's active buffer and cursor, and bails out if predictions are
    ///    disabled there.
    /// 2. Finds a jump target via [`Self::next_diagnostic_location`].
    /// 3. Requests a prediction at that target.
    ///
    /// If a current prediction appeared while the request was in flight, the result is
    /// downgraded to a `CurrentPreferred` rejection.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Read the active buffer/cursor; `None` (also when the store was released) aborts.
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, search from the start of the buffer.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // Passes a default (empty) search range, presumably so that no diagnostic in
                // the active buffer is excluded as "already covered".
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A prediction may have arrived while this request was running; if so, keep
                // it and report this one as `CurrentPreferred` instead.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1499
1500 fn predictions_enabled_at(
1501 snapshot: &BufferSnapshot,
1502 position: Option<language::Anchor>,
1503 cx: &App,
1504 ) -> bool {
1505 let file = snapshot.file();
1506 let all_settings = all_language_settings(file, cx);
1507 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1508 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1509 {
1510 return false;
1511 }
1512
1513 if let Some(last_position) = position {
1514 let settings = snapshot.settings_at(last_position, cx);
1515
1516 if !settings.edit_predictions_disabled_in.is_empty()
1517 && let Some(scope) = snapshot.language_scope_at(last_position)
1518 && let Some(scope_name) = scope.override_name()
1519 && settings
1520 .edit_predictions_disabled_in
1521 .iter()
1522 .any(|s| s == scope_name)
1523 {
1524 return false;
1525 }
1526 }
1527
1528 true
1529 }
1530
1531 #[cfg(not(test))]
1532 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1533 #[cfg(test)]
1534 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1535
1536 fn queue_prediction_refresh(
1537 &mut self,
1538 project: Entity<Project>,
1539 throttle_entity: EntityId,
1540 cx: &mut Context<Self>,
1541 do_refresh: impl FnOnce(
1542 WeakEntity<Self>,
1543 &mut AsyncApp,
1544 )
1545 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1546 + 'static,
1547 ) {
1548 let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1549 let drop_on_cancel = is_ollama;
1550 let max_pending_predictions = if is_ollama { 1 } else { 2 };
1551 let project_state = self.get_or_init_project(&project, cx);
1552 let pending_prediction_id = project_state.next_pending_prediction_id;
1553 project_state.next_pending_prediction_id += 1;
1554 let last_request = project_state.last_prediction_refresh;
1555
1556 let task = cx.spawn(async move |this, cx| {
1557 if let Some((last_entity, last_timestamp)) = last_request
1558 && throttle_entity == last_entity
1559 && let Some(timeout) =
1560 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1561 {
1562 cx.background_executor().timer(timeout).await;
1563 }
1564
1565 // If this task was cancelled before the throttle timeout expired,
1566 // do not perform a request.
1567 let mut is_cancelled = true;
1568 this.update(cx, |this, cx| {
1569 let project_state = this.get_or_init_project(&project, cx);
1570 if !project_state
1571 .cancelled_predictions
1572 .remove(&pending_prediction_id)
1573 {
1574 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1575 is_cancelled = false;
1576 }
1577 })
1578 .ok();
1579 if is_cancelled {
1580 return None;
1581 }
1582
1583 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1584 let new_prediction_id = new_prediction_result
1585 .as_ref()
1586 .map(|(prediction, _)| prediction.id.clone());
1587
1588 // When a prediction completes, remove it from the pending list, and cancel
1589 // any pending predictions that were enqueued before it.
1590 this.update(cx, |this, cx| {
1591 let project_state = this.get_or_init_project(&project, cx);
1592
1593 let is_cancelled = project_state
1594 .cancelled_predictions
1595 .remove(&pending_prediction_id);
1596
1597 let new_current_prediction = if !is_cancelled
1598 && let Some((prediction_result, requested_by)) = new_prediction_result
1599 {
1600 match prediction_result.prediction {
1601 Ok(prediction) => {
1602 let new_prediction = CurrentEditPrediction {
1603 requested_by,
1604 prediction,
1605 was_shown: false,
1606 shown_with: None,
1607 };
1608
1609 if let Some(current_prediction) =
1610 project_state.current_prediction.as_ref()
1611 {
1612 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1613 {
1614 this.reject_current_prediction(
1615 EditPredictionRejectReason::Replaced,
1616 &project,
1617 );
1618
1619 Some(new_prediction)
1620 } else {
1621 this.reject_prediction(
1622 new_prediction.prediction.id,
1623 EditPredictionRejectReason::CurrentPreferred,
1624 false,
1625 );
1626 None
1627 }
1628 } else {
1629 Some(new_prediction)
1630 }
1631 }
1632 Err(reject_reason) => {
1633 this.reject_prediction(prediction_result.id, reject_reason, false);
1634 None
1635 }
1636 }
1637 } else {
1638 None
1639 };
1640
1641 let project_state = this.get_or_init_project(&project, cx);
1642
1643 if let Some(new_prediction) = new_current_prediction {
1644 project_state.current_prediction = Some(new_prediction);
1645 }
1646
1647 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1648 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1649 if pending_prediction.id == pending_prediction_id {
1650 pending_predictions.remove(ix);
1651 for pending_prediction in pending_predictions.drain(0..ix) {
1652 project_state.cancel_pending_prediction(pending_prediction, cx)
1653 }
1654 break;
1655 }
1656 }
1657 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1658 cx.notify();
1659 })
1660 .ok();
1661
1662 new_prediction_id
1663 });
1664
1665 if project_state.pending_predictions.len() < max_pending_predictions {
1666 project_state.pending_predictions.push(PendingPrediction {
1667 id: pending_prediction_id,
1668 task,
1669 drop_on_cancel,
1670 });
1671 } else {
1672 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1673 project_state.pending_predictions.push(PendingPrediction {
1674 id: pending_prediction_id,
1675 task,
1676 drop_on_cancel,
1677 });
1678 project_state.cancel_pending_prediction(pending_prediction, cx);
1679 }
1680 }
1681
1682 pub fn request_prediction(
1683 &mut self,
1684 project: &Entity<Project>,
1685 active_buffer: &Entity<Buffer>,
1686 position: language::Anchor,
1687 trigger: PredictEditsRequestTrigger,
1688 cx: &mut Context<Self>,
1689 ) -> Task<Result<Option<EditPredictionResult>>> {
1690 self.request_prediction_internal(
1691 project.clone(),
1692 active_buffer.clone(),
1693 position,
1694 trigger,
1695 cx.has_flag::<Zeta2FeatureFlag>(),
1696 cx,
1697 )
1698 }
1699
    /// Builds the model input for a prediction at `position` in `active_buffer` and dispatches
    /// it to the configured [`EditPredictionModel`].
    ///
    /// Along the way it:
    /// - snapshots the project's edit history and recorded user actions. A synthetic
    ///   `CursorMovement` action is appended when the last action was in this buffer at a
    ///   different offset.
    /// - captures a telemetry example when data collection allows it and sampling selects
    ///   this request.
    /// - for `Zeta1`, may also fire a detached `Testing` request through zeta2.
    ///
    /// It may retry once, with `allow_jump = false`, at the nearest diagnostic outside the
    /// ±`DIAGNOSTIC_LINES_RANGE` rows around the cursor. This happens only when all three
    /// hold:
    /// - the model returns no prediction;
    /// - `allow_jump` is set;
    /// - there is edit history.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows within this distance of the cursor are covered by this request; a diagnostic
        // jump target must lie outside that window.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Ensure the project state exists, then re-borrow it immutably (the `unwrap` cannot
        // fail after initialization).
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events_split_by_pause(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // Record a cursor movement if the cursor has moved away from the last recorded action
        // in this buffer.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        // Examples may only be captured when the file, every event and every related file
        // pass the data-collection checks.
        let can_collect_example = snapshot
            .file()
            .is_some_and(|file| self.can_collect_file(&project, file, cx))
            && self.can_collect_events(&inputs.events, cx)
            && self.can_collect_related_files(&project, cx);

        if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
            let events_for_capture =
                self.edit_history_for_project_with_pause_split_last_event(&project, cx);
            let related_files_for_capture = inputs.related_files.clone();
            if let Some(example_task) = capture_example::capture_example(
                project.clone(),
                active_buffer.clone(),
                position,
                events_for_capture,
                related_files_for_capture,
                false,
                cx,
            ) {
                cx.spawn(async move |_this, _cx| {
                    let example = example_task.await?;
                    telemetry::event!("Edit Prediction Example Captured", example = example);
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
            }
        }
        let task = match self.edit_prediction_model {
            EditPredictionModel::Zeta1 => {
                // Optionally shadow the request through zeta2 with the `Testing` trigger; its
                // result is discarded.
                if should_send_testing_zeta2_request() {
                    let mut zeta2_inputs = inputs.clone();
                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
                    zeta2::request_prediction_with_zeta2(
                        self,
                        zeta2_inputs,
                        Default::default(),
                        cx,
                    )
                    .detach();
                }
                zeta1::request_prediction_with_zeta1(self, inputs, cx)
            }
            EditPredictionModel::Zeta2 { version } => {
                zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
            }
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction here: try once more at the nearest diagnostic outside the window
            // that was already sent (the retry passes `allow_jump = false`, so no recursion).
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1850
    /// Picks a location to jump to for a diagnostic-driven prediction.
    ///
    /// First it looks in the active buffer for the diagnostic group whose primary entry starts
    /// nearest (by row) to `active_buffer_cursor_point`. Groups whose primary range overlaps
    /// `active_buffer_diagnostic_search_range` are ignored.
    ///
    /// If there is no such group, it falls back to another file in the project:
    /// - among the files reporting diagnostics, it picks the one whose path shares the most
    ///   leading components with the active buffer's path;
    /// - it opens that file and returns the start of its diagnostic with the smallest severity
    ///   value (presumably the most severe, assuming LSP ordering — confirm).
    ///
    /// Returns `Ok(None)` when no candidate is found. The only error path is a failure to open
    /// the fallback buffer.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another file with diagnostics.
        if jump_location.is_none() {
            // `None` when the active buffer has no file (e.g. an untitled buffer).
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1929
1930 async fn send_raw_llm_request(
1931 request: RawCompletionRequest,
1932 client: Arc<Client>,
1933 custom_url: Option<Arc<Url>>,
1934 llm_token: LlmApiToken,
1935 app_version: Version,
1936 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1937 let url = if let Some(custom_url) = custom_url {
1938 custom_url.as_ref().clone()
1939 } else {
1940 client
1941 .http_client()
1942 .build_zed_llm_url("/predict_edits/raw", &[])?
1943 };
1944
1945 Self::send_api_request(
1946 |builder| {
1947 let req = builder
1948 .uri(url.as_ref())
1949 .body(serde_json::to_string(&request)?.into());
1950 Ok(req?)
1951 },
1952 client,
1953 llm_token,
1954 app_version,
1955 true,
1956 )
1957 .await
1958 }
1959
1960 pub(crate) async fn send_v3_request(
1961 input: ZetaPromptInput,
1962 prompt_version: ZetaVersion,
1963 client: Arc<Client>,
1964 llm_token: LlmApiToken,
1965 app_version: Version,
1966 trigger: PredictEditsRequestTrigger,
1967 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1968 let url = client
1969 .http_client()
1970 .build_zed_llm_url("/predict_edits/v3", &[])?;
1971
1972 let request = PredictEditsV3Request {
1973 input,
1974 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1975 prompt_version,
1976 trigger,
1977 };
1978
1979 Self::send_api_request(
1980 |builder| {
1981 let req = builder
1982 .uri(url.as_ref())
1983 .body(serde_json::to_string(&request)?.into());
1984 Ok(req?)
1985 },
1986 client,
1987 llm_token,
1988 app_version,
1989 true,
1990 )
1991 .await
1992 }
1993
1994 fn handle_api_response<T>(
1995 this: &WeakEntity<Self>,
1996 response: Result<(T, Option<EditPredictionUsage>)>,
1997 cx: &mut gpui::AsyncApp,
1998 ) -> Result<T> {
1999 match response {
2000 Ok((data, usage)) => {
2001 if let Some(usage) = usage {
2002 this.update(cx, |this, cx| {
2003 this.user_store.update(cx, |user_store, cx| {
2004 user_store.update_edit_prediction_usage(usage, cx);
2005 });
2006 })
2007 .ok();
2008 }
2009 Ok(data)
2010 }
2011 Err(err) => {
2012 if err.is::<ZedUpdateRequiredError>() {
2013 cx.update(|cx| {
2014 this.update(cx, |this, _cx| {
2015 this.update_required = true;
2016 })
2017 .ok();
2018
2019 let error_message: SharedString = err.to_string().into();
2020 show_app_notification(
2021 NotificationId::unique::<ZedUpdateRequiredError>(),
2022 cx,
2023 move |cx| {
2024 cx.new(|cx| {
2025 ErrorMessagePrompt::new(error_message.clone(), cx)
2026 .with_link_button("Update Zed", "https://zed.dev/releases")
2027 })
2028 },
2029 );
2030 });
2031 }
2032 Err(err)
2033 }
2034 }
2035 }
2036
    /// Sends a JSON `POST` request to the LLM service and deserializes a successful response.
    ///
    /// `build` receives a builder already populated with the method, `Content-Type`, the Zed
    /// version header and, when a token is available, the `Authorization` header. It must set
    /// the URI and body. It may be called more than once (see the retry below).
    ///
    /// The bearer token is taken from the `ZED_PREDICT_EDITS_TOKEN` environment variable when
    /// set, otherwise from `llm_token`. When `require_auth` is false, a failure to acquire a
    /// token simply omits the header. If the server indicates the token needs refreshing, the
    /// token is refreshed and the request is retried exactly once.
    ///
    /// Returns the decoded body together with usage information parsed from the response
    /// headers (`None` if absent or unparsable).
    ///
    /// # Errors
    /// - [`ZedUpdateRequiredError`] when the response advertises a minimum required version
    ///   newer than `app_version`. This check runs before the status check, so it applies to
    ///   successful responses too.
    /// - An error carrying the status and body for any other non-success response.
    /// - Token acquisition/refresh, transport, and JSON (de)serialization errors.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
            Some(custom_token)
        } else if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // An unparsable version header is ignored rather than treated as an error.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Stale token: refresh once and resend.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2108
2109 pub fn refresh_context(
2110 &mut self,
2111 project: &Entity<Project>,
2112 buffer: &Entity<language::Buffer>,
2113 cursor_position: language::Anchor,
2114 cx: &mut Context<Self>,
2115 ) {
2116 self.get_or_init_project(project, cx)
2117 .context
2118 .update(cx, |store, cx| {
2119 store.refresh(buffer.clone(), cursor_position, cx);
2120 });
2121 }
2122
2123 #[cfg(feature = "cli-support")]
2124 pub fn set_context_for_buffer(
2125 &mut self,
2126 project: &Entity<Project>,
2127 related_files: Vec<RelatedFile>,
2128 cx: &mut Context<Self>,
2129 ) {
2130 self.get_or_init_project(project, cx)
2131 .context
2132 .update(cx, |store, cx| {
2133 store.set_related_files(related_files, cx);
2134 });
2135 }
2136
2137 #[cfg(feature = "cli-support")]
2138 pub fn set_recent_paths_for_project(
2139 &mut self,
2140 project: &Entity<Project>,
2141 paths: impl IntoIterator<Item = project::ProjectPath>,
2142 cx: &mut Context<Self>,
2143 ) {
2144 let project_state = self.get_or_init_project(project, cx);
2145 project_state.recent_paths = paths.into_iter().collect();
2146 }
2147
2148 fn is_file_open_source(
2149 &self,
2150 project: &Entity<Project>,
2151 file: &Arc<dyn File>,
2152 cx: &App,
2153 ) -> bool {
2154 if !file.is_local() || file.is_private() {
2155 return false;
2156 }
2157 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2158 return false;
2159 };
2160 project_state
2161 .license_detection_watchers
2162 .get(&file.worktree_id(cx))
2163 .as_ref()
2164 .is_some_and(|watcher| watcher.is_project_open_source())
2165 }
2166
2167 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2168 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2169 }
2170
2171 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2172 if !self.data_collection_choice.is_enabled(cx) {
2173 return false;
2174 }
2175 events.iter().all(|event| {
2176 matches!(
2177 event.as_ref(),
2178 zeta_prompt::Event::BufferChange {
2179 in_open_source_repo: true,
2180 ..
2181 }
2182 )
2183 })
2184 }
2185
2186 fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2187 if !self.data_collection_choice.is_enabled(cx) {
2188 return false;
2189 }
2190
2191 let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2192
2193 related_with_buffers.iter().all(|(_, buffer)| {
2194 buffer
2195 .read(cx)
2196 .file()
2197 .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2198 })
2199 }
2200
2201 fn load_data_collection_choice() -> DataCollectionChoice {
2202 let choice = KEY_VALUE_STORE
2203 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2204 .log_err()
2205 .flatten();
2206
2207 match choice.as_deref() {
2208 Some("true") => DataCollectionChoice::Enabled,
2209 Some("false") => DataCollectionChoice::Disabled,
2210 Some(_) => {
2211 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2212 DataCollectionChoice::NotAnswered
2213 }
2214 None => DataCollectionChoice::NotAnswered,
2215 }
2216 }
2217
2218 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2219 self.data_collection_choice = self.data_collection_choice.toggle();
2220 let new_choice = self.data_collection_choice;
2221 let is_enabled = new_choice.is_enabled(cx);
2222 db::write_and_log(cx, move || {
2223 KEY_VALUE_STORE.write_kvp(
2224 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2225 is_enabled.to_string(),
2226 )
2227 });
2228 }
2229
2230 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2231 self.shown_predictions.iter()
2232 }
2233
2234 pub fn shown_completions_len(&self) -> usize {
2235 self.shown_predictions.len()
2236 }
2237
2238 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2239 self.rated_predictions.contains(id)
2240 }
2241
2242 pub fn rate_prediction(
2243 &mut self,
2244 prediction: &EditPrediction,
2245 rating: EditPredictionRating,
2246 feedback: String,
2247 cx: &mut Context<Self>,
2248 ) {
2249 self.rated_predictions.insert(prediction.id.clone());
2250 telemetry::event!(
2251 "Edit Prediction Rated",
2252 rating,
2253 inputs = prediction.inputs,
2254 output = prediction
2255 .edit_preview
2256 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2257 feedback
2258 );
2259 self.client.telemetry().flush_events().detach();
2260 cx.notify();
2261 }
2262}
2263
2264pub(crate) fn filter_redundant_excerpts(
2265 mut related_files: Vec<RelatedFile>,
2266 cursor_path: &Path,
2267 cursor_row_range: Range<u32>,
2268) -> Vec<RelatedFile> {
2269 for file in &mut related_files {
2270 if file.path.as_ref() == cursor_path {
2271 file.excerpts.retain(|excerpt| {
2272 excerpt.row_range.start < cursor_row_range.start
2273 || excerpt.row_range.end > cursor_row_range.end
2274 });
2275 }
2276 }
2277 related_files.retain(|file| !file.excerpts.is_empty());
2278 related_files
2279}
2280
2281#[derive(Error, Debug)]
2282#[error(
2283 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2284)]
2285pub struct ZedUpdateRequiredError {
2286 minimum_version: Version,
2287}
2288
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No choice has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2295
2296impl DataCollectionChoice {
2297 pub fn is_enabled(self, cx: &App) -> bool {
2298 if cx.is_staff() {
2299 return true;
2300 }
2301 match self {
2302 Self::Enabled => true,
2303 Self::NotAnswered | Self::Disabled => false,
2304 }
2305 }
2306
2307 #[must_use]
2308 pub fn toggle(&self) -> DataCollectionChoice {
2309 match self {
2310 Self::Enabled => Self::Disabled,
2311 Self::Disabled => Self::Enabled,
2312 Self::NotAnswered => Self::Enabled,
2313 }
2314 }
2315}
2316
2317impl From<bool> for DataCollectionChoice {
2318 fn from(value: bool) -> Self {
2319 match value {
2320 true => DataCollectionChoice::Enabled,
2321 false => DataCollectionChoice::Disabled,
2322 }
2323 }
2324}
2325
2326struct ZedPredictUpsell;
2327
2328impl Dismissable for ZedPredictUpsell {
2329 const KEY: &'static str = "dismissed-edit-predict-upsell";
2330
2331 fn dismissed() -> bool {
2332 // To make this backwards compatible with older versions of Zed, we
2333 // check if the user has seen the previous Edit Prediction Onboarding
2334 // before, by checking the data collection choice which was written to
2335 // the database once the user clicked on "Accept and Enable"
2336 if KEY_VALUE_STORE
2337 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2338 .log_err()
2339 .is_some_and(|s| s.is_some())
2340 {
2341 return true;
2342 }
2343
2344 KEY_VALUE_STORE
2345 .read_kvp(Self::KEY)
2346 .log_err()
2347 .is_some_and(|s| s.is_some())
2348 }
2349}
2350
2351pub fn should_show_upsell_modal() -> bool {
2352 !ZedPredictUpsell::dismissed()
2353}
2354
2355pub fn init(cx: &mut App) {
2356 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2357 workspace.register_action(
2358 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2359 ZedPredictModal::toggle(
2360 workspace,
2361 workspace.user_store().clone(),
2362 workspace.client().clone(),
2363 window,
2364 cx,
2365 )
2366 },
2367 );
2368
2369 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2370 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2371 settings
2372 .project
2373 .all_languages
2374 .edit_predictions
2375 .get_or_insert_default()
2376 .provider = Some(EditPredictionProvider::None)
2377 });
2378 });
2379 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2380 EditPredictionStore::try_global(cx).and_then(|store| {
2381 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2382 })
2383 }
2384
2385 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2386 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2387 copilot_ui::initiate_sign_in(copilot, window, cx);
2388 }
2389 });
2390 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2391 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2392 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2393 }
2394 });
2395 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2396 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2397 copilot_ui::initiate_sign_out(copilot, window, cx);
2398 }
2399 });
2400 })
2401 .detach();
2402}