1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use text::Edit;
40use workspace::Workspace;
41use zeta_prompt::ZetaPromptInput;
42use zeta_prompt::ZetaVersion;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
50use std::{env, mem};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59pub mod ollama;
60mod onboarding_modal;
61pub mod open_ai_response;
62mod prediction;
63pub mod sweep_ai;
64
65pub mod udiff;
66
67mod capture_example;
68mod zed_edit_prediction_delegate;
69pub mod zeta1;
70pub mod zeta2;
71
72#[cfg(test)]
73mod edit_prediction_tests;
74
75use crate::capture_example::{
76 should_sample_edit_prediction_example_capture, should_send_testing_zeta2_request,
77};
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::ollama::Ollama;
81use crate::onboarding_modal::ZedPredictModal;
82pub use crate::prediction::EditPrediction;
83pub use crate::prediction::EditPredictionId;
84use crate::prediction::EditPredictionResult;
85pub use crate::sweep_ai::SweepAi;
86pub use capture_example::capture_example;
87pub use language_model::ApiKeyState;
88pub use telemetry_events::EditPredictionRating;
89pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
90
// Actions exposed by this crate, declared under the `edit_prediction` namespace.
// The `///` lines inside the macro are the actions' own documentation and are
// intentionally left untouched.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Maximum row distance between a new edit range and the in-progress event's
/// edit range for the two to be considered for coalescing into one event.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// NOTE(review): presumably the idle interval after which an in-progress event is
// treated as paused; its use is outside this view — confirm.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the persistence key for the data-collection choice —
// confirm against `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce interval for reporting rejected
// predictions — confirm against `handle_rejected_predictions`.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);

/// Model identifier read lazily (once) from the `ZED_ZETA_MODEL` environment
/// variable; `None` when the variable is unset or not valid Unicode.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
    LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
110
/// Feature flag named `zeta2`, always enabled for staff. When it is on,
/// project diagnostics updates trigger a prediction refresh
/// (see `EditPredictionStore::handle_project_event`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

/// Feature flag named `edit-prediction-example-capture`, always enabled for staff.
// NOTE(review): presumably gates the `capture_example` sampling paths; the
// check itself is outside this view — confirm.
pub struct EditPredictionExampleCaptureFeatureFlag;

impl FeatureFlag for EditPredictionExampleCaptureFeatureFlag {
    const NAME: &'static str = "edit-prediction-example-capture";

    fn enabled_for_staff() -> bool {
        true
    }
}
130
/// App-global handle to the single `EditPredictionStore` entity, installed by
/// `EditPredictionStore::global` and looked up by `try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
135
/// Central store for edit predictions: tracks per-project edit history and
/// prediction state, and owns the configured prediction backends.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Shared with the rejection-reporting background task; refreshed whenever
    // the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    update_required: bool,
    // Set to `Zeta2` by `new`, even though the enum's `Default` is `Zeta1`.
    edit_prediction_model: EditPredictionModel,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    // Sending half of the channel drained by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
    // Parsed from the `ZED_PREDICT_EDITS_URL` environment variable, if set and valid.
    custom_predict_edits_url: Option<Arc<Url>>,
}
153
/// Backend used to produce edit predictions.
///
/// `Zeta1` is the `Default`, but `EditPredictionStore::new` starts with
/// `Zeta2 { version: Default::default() }`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2 {
        version: ZetaVersion,
    },
    Sweep,
    Mercury,
    Ollama,
}
165
/// Everything a prediction backend receives for a single request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Cursor position the prediction is requested at.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    // Most recently active paths first (see `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Present only when a debug listener was installed via `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
180
/// Diagnostic events streamed to a listener installed with
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when the project's related-excerpt store starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Emitted when the related-excerpt store finishes a refresh; `metadata`
/// holds labelled, display-ready statistics.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts for `buffer` at `position`.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Emitted when a prediction request for `buffer` at `position` completes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}
216
/// Capacity of each project's user-action history; the oldest record is
/// evicted once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One user action observed in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    // Zero-based row of `offset` in the post-edit snapshot.
    pub line_number: u32,
    // Offset where the last edit of the change ended, in the post-edit snapshot.
    pub offset: usize,
    // Wall-clock time in milliseconds since the Unix epoch (0 if the clock is
    // before the epoch).
    pub timestamp_epoch_ms: u64,
}

/// Coarse classification of a user action. For edits, a "char" action is one
/// where every edit inserted (or deleted) exactly one byte.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    // NOTE(review): not produced by the edit-tracking code in this view — confirm its source.
    CursorMovement,
}
236
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change described by `event`.
    pub old_snapshot: TextBufferSnapshot,
}
243
/// Per-project bookkeeping held by `EditPredictionStore`.
struct ProjectState {
    // Finalized edit events, oldest first.
    events: VecDeque<StoredEvent>,
    // In-progress event that subsequent nearby edits may still be coalesced into.
    last_event: Option<LastEvent>,
    // Most recently activated paths, most recent at the front, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    // Keyed by buffer entity id; entries are removed when the buffer is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two in-flight predictions (fixed-capacity `ArrayVec`).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One watcher per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded by `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    // Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    // Created lazily by `start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
261
262impl ProjectState {
263 fn record_user_action(&mut self, action: UserActionRecord) {
264 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
265 self.user_actions.pop_front();
266 }
267 self.user_actions.push_back(action);
268 }
269
270 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
271 self.events
272 .iter()
273 .cloned()
274 .chain(
275 self.last_event
276 .as_ref()
277 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
278 )
279 .collect()
280 }
281
282 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
283 self.events
284 .iter()
285 .cloned()
286 .chain(self.last_event.as_ref().iter().flat_map(|event| {
287 let (one, two) = event.split_by_pause();
288 let one = one.finalize(&self.license_detection_watchers, cx);
289 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
290 one.into_iter().chain(two)
291 }))
292 .collect()
293 }
294
295 fn cancel_pending_prediction(
296 &mut self,
297 pending_prediction: PendingPrediction,
298 cx: &mut Context<EditPredictionStore>,
299 ) {
300 self.cancelled_predictions.insert(pending_prediction.id);
301
302 if pending_prediction.drop_on_cancel {
303 drop(pending_prediction.task);
304 } else {
305 cx.spawn(async move |this, cx| {
306 let Some(prediction_id) = pending_prediction.task.await else {
307 return;
308 };
309
310 this.update(cx, |this, cx| {
311 this.reject_prediction(
312 prediction_id,
313 EditPredictionRejectReason::Canceled,
314 false,
315 edit_prediction_types::EditPredictionDismissReason::Ignored,
316 cx,
317 );
318 })
319 .ok();
320 })
321 .detach()
322 }
323 }
324
325 fn active_buffer(
326 &self,
327 project: &Entity<Project>,
328 cx: &App,
329 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
330 let project = project.read(cx);
331 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
332 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
333 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
334 Some((active_buffer, registered_buffer.last_position))
335 }
336}
337
/// The prediction currently offered for a project, plus what triggered it and
/// whether/how it has been displayed.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
345
346impl CurrentEditPrediction {
347 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
348 let Some(new_edits) = self
349 .prediction
350 .interpolate(&self.prediction.buffer.read(cx))
351 else {
352 return false;
353 };
354
355 if self.prediction.buffer != old_prediction.prediction.buffer {
356 return true;
357 }
358
359 let Some(old_edits) = old_prediction
360 .prediction
361 .interpolate(&old_prediction.prediction.buffer.read(cx))
362 else {
363 return true;
364 };
365
366 let requested_by_buffer_id = self.requested_by.buffer_id();
367
368 // This reduces the occurrence of UI thrash from replacing edits
369 //
370 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
371 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
372 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
373 && old_edits.len() == 1
374 && new_edits.len() == 1
375 {
376 let (old_range, old_text) = &old_edits[0];
377 let (new_range, new_text) = &new_edits[0];
378 new_range == old_range && new_text.starts_with(old_text.as_ref())
379 } else {
380 true
381 }
382 }
383}
384
385#[derive(Debug, Clone)]
386enum PredictionRequestedBy {
387 DiagnosticsUpdate,
388 Buffer(EntityId),
389}
390
391impl PredictionRequestedBy {
392 pub fn buffer_id(&self) -> Option<EntityId> {
393 match self {
394 PredictionRequestedBy::DiagnosticsUpdate => None,
395 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
396 }
397 }
398}
399
/// An in-flight prediction request. The task resolves to the id of the
/// resulting prediction, if any.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}

/// A prediction from the perspective of a buffer.
// NOTE(review): presumably `Local` means the edits are in this buffer and
// `Jump` means they target another location — confirm at the construction site.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
415
/// Test-only convenience: both variants dereference to the wrapped prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Local { prediction } | Self::Jump { prediction } => prediction,
        }
    }
}
427
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    // File and snapshot as of the last reported change; replaced on each new change.
    file: Option<Arc<dyn File>>,
    snapshot: TextBufferSnapshot,
    // Last known cursor position, returned by `ProjectState::active_buffer`.
    last_position: Option<Anchor>,
    // Buffer `Edited` subscription and buffer release observer.
    _subscriptions: [gpui::Subscription; 2],
}

/// The in-progress (not yet stored) edit event of a project.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // Union of the anchored edit ranges accumulated into this event.
    edit_range: Option<Range<Anchor>>,
    // Boundary used by `split_by_pause`; `None` means the event is not split.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
445
446impl LastEvent {
447 pub fn finalize(
448 &self,
449 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
450 cx: &App,
451 ) -> Option<StoredEvent> {
452 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
453 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
454
455 let in_open_source_repo =
456 [self.new_file.as_ref(), self.old_file.as_ref()]
457 .iter()
458 .all(|file| {
459 file.is_some_and(|file| {
460 license_detection_watchers
461 .get(&file.worktree_id(cx))
462 .is_some_and(|watcher| watcher.is_project_open_source())
463 })
464 });
465
466 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
467
468 if path == old_path && diff.is_empty() {
469 None
470 } else {
471 Some(StoredEvent {
472 event: Arc::new(zeta_prompt::Event::BufferChange {
473 old_path,
474 path,
475 diff,
476 in_open_source_repo,
477 // TODO: Actually detect if this edit was predicted or not
478 predicted: false,
479 }),
480 old_snapshot: self.old_snapshot.clone(),
481 })
482 }
483 }
484
485 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
486 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
487 return (self.clone(), None);
488 };
489
490 let before = LastEvent {
491 old_snapshot: self.old_snapshot.clone(),
492 new_snapshot: boundary_snapshot.clone(),
493 old_file: self.old_file.clone(),
494 new_file: self.new_file.clone(),
495 edit_range: None,
496 snapshot_after_last_editing_pause: None,
497 last_edit_time: self.last_edit_time,
498 };
499
500 let after = LastEvent {
501 old_snapshot: boundary_snapshot.clone(),
502 new_snapshot: self.new_snapshot.clone(),
503 old_file: self.old_file.clone(),
504 new_file: self.new_file.clone(),
505 edit_range: None,
506 snapshot_after_last_editing_pause: None,
507 last_edit_time: self.last_edit_time,
508 };
509
510 (before, Some(after))
511 }
512}
513
514pub(crate) fn compute_diff_between_snapshots(
515 old_snapshot: &TextBufferSnapshot,
516 new_snapshot: &TextBufferSnapshot,
517) -> Option<String> {
518 let edits: Vec<Edit<usize>> = new_snapshot
519 .edits_since::<usize>(&old_snapshot.version)
520 .collect();
521
522 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
523
524 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
525 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
526 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
527 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
528
529 const CONTEXT_LINES: u32 = 3;
530
531 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
532 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
533 let old_context_end_row =
534 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
535 let new_context_end_row =
536 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
537
538 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
539 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
540 let old_end_line_offset = old_snapshot
541 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
542 let new_end_line_offset = new_snapshot
543 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
544 let old_edit_range = old_start_line_offset..old_end_line_offset;
545 let new_edit_range = new_start_line_offset..new_end_line_offset;
546
547 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
548 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
549
550 let diff = language::unified_diff_with_offsets(
551 &old_region_text,
552 &new_region_text,
553 old_context_start_row,
554 new_context_start_row,
555 );
556
557 Some(diff)
558}
559
560fn buffer_path_with_id_fallback(
561 file: Option<&Arc<dyn File>>,
562 snapshot: &TextBufferSnapshot,
563 cx: &App,
564) -> Arc<Path> {
565 if let Some(file) = file {
566 file.full_path(cx).into()
567 } else {
568 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
569 }
570}
571
572impl EditPredictionStore {
573 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
574 cx.try_global::<EditPredictionStoreGlobal>()
575 .map(|global| global.0.clone())
576 }
577
578 pub fn global(
579 client: &Arc<Client>,
580 user_store: &Entity<UserStore>,
581 cx: &mut App,
582 ) -> Entity<Self> {
583 cx.try_global::<EditPredictionStoreGlobal>()
584 .map(|global| global.0.clone())
585 .unwrap_or_else(|| {
586 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
587 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
588 ep_store
589 })
590 }
591
592 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
593 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
594 let data_collection_choice = Self::load_data_collection_choice();
595
596 let llm_token = LlmApiToken::default();
597
598 let (reject_tx, reject_rx) = mpsc::unbounded();
599 cx.background_spawn({
600 let client = client.clone();
601 let llm_token = llm_token.clone();
602 let app_version = AppVersion::global(cx);
603 let background_executor = cx.background_executor().clone();
604 async move {
605 Self::handle_rejected_predictions(
606 reject_rx,
607 client,
608 llm_token,
609 app_version,
610 background_executor,
611 )
612 .await
613 }
614 })
615 .detach();
616
617 let this = Self {
618 projects: HashMap::default(),
619 client,
620 user_store,
621 llm_token,
622 _llm_token_subscription: cx.subscribe(
623 &refresh_llm_token_listener,
624 |this, _listener, _event, cx| {
625 let client = this.client.clone();
626 let llm_token = this.llm_token.clone();
627 cx.spawn(async move |_this, _cx| {
628 llm_token.refresh(&client).await?;
629 anyhow::Ok(())
630 })
631 .detach_and_log_err(cx);
632 },
633 ),
634 update_required: false,
635 edit_prediction_model: EditPredictionModel::Zeta2 {
636 version: Default::default(),
637 },
638 sweep_ai: SweepAi::new(cx),
639 mercury: Mercury::new(cx),
640 ollama: Ollama::new(),
641
642 data_collection_choice,
643 reject_predictions_tx: reject_tx,
644 rated_predictions: Default::default(),
645 shown_predictions: Default::default(),
646 custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
647 Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
648 Err(_) => None,
649 },
650 };
651
652 this
653 }
654
655 #[cfg(test)]
656 pub fn set_custom_predict_edits_url(&mut self, url: Url) {
657 self.custom_predict_edits_url = Some(url.into());
658 }
659
660 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
661 self.edit_prediction_model = model;
662 }
663
664 pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
665 use ui::IconName;
666 match self.edit_prediction_model {
667 EditPredictionModel::Ollama => {
668 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
669 }
670 EditPredictionModel::Sweep => {
671 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
672 .with_disabled(IconName::SweepAiDisabled)
673 .with_up(IconName::SweepAiUp)
674 .with_down(IconName::SweepAiDown)
675 .with_error(IconName::SweepAiError)
676 }
677 EditPredictionModel::Mercury => {
678 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
679 }
680 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
681 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
682 .with_disabled(IconName::ZedPredictDisabled)
683 .with_up(IconName::ZedPredictUp)
684 .with_down(IconName::ZedPredictDown)
685 .with_error(IconName::ZedPredictError)
686 }
687 }
688 }
689
690 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
691 self.sweep_ai.api_token.read(cx).has_key()
692 }
693
694 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
695 self.mercury.api_token.read(cx).has_key()
696 }
697
698 pub fn clear_history(&mut self) {
699 for project_state in self.projects.values_mut() {
700 project_state.events.clear();
701 project_state.last_event.take();
702 }
703 }
704
705 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
706 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
707 project_state.events.clear();
708 project_state.last_event.take();
709 }
710 }
711
712 pub fn edit_history_for_project(
713 &self,
714 project: &Entity<Project>,
715 cx: &App,
716 ) -> Vec<StoredEvent> {
717 self.projects
718 .get(&project.entity_id())
719 .map(|project_state| project_state.events(cx))
720 .unwrap_or_default()
721 }
722
723 pub fn edit_history_for_project_with_pause_split_last_event(
724 &self,
725 project: &Entity<Project>,
726 cx: &App,
727 ) -> Vec<StoredEvent> {
728 self.projects
729 .get(&project.entity_id())
730 .map(|project_state| project_state.events_split_by_pause(cx))
731 .unwrap_or_default()
732 }
733
734 pub fn context_for_project<'a>(
735 &'a self,
736 project: &Entity<Project>,
737 cx: &'a mut App,
738 ) -> Vec<RelatedFile> {
739 self.projects
740 .get(&project.entity_id())
741 .map(|project| {
742 project
743 .context
744 .update(cx, |context, cx| context.related_files(cx))
745 })
746 .unwrap_or_default()
747 }
748
749 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
750 self.projects
751 .get(&project.entity_id())
752 .and_then(|project| project.copilot.clone())
753 }
754
755 pub fn start_copilot_for_project(
756 &mut self,
757 project: &Entity<Project>,
758 cx: &mut Context<Self>,
759 ) -> Option<Entity<Copilot>> {
760 if DisableAiSettings::get(None, cx).disable_ai {
761 return None;
762 }
763 let state = self.get_or_init_project(project, cx);
764
765 if state.copilot.is_some() {
766 return state.copilot.clone();
767 }
768 let _project = project.clone();
769 let project = project.read(cx);
770
771 let node = project.node_runtime().cloned();
772 if let Some(node) = node {
773 let next_id = project.languages().next_language_server_id();
774 let fs = project.fs().clone();
775
776 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
777 state.copilot = Some(copilot.clone());
778 Some(copilot)
779 } else {
780 None
781 }
782 }
783
784 pub fn context_for_project_with_buffers<'a>(
785 &'a self,
786 project: &Entity<Project>,
787 cx: &'a mut App,
788 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
789 self.projects
790 .get(&project.entity_id())
791 .map(|project| {
792 project
793 .context
794 .update(cx, |context, cx| context.related_files_with_buffers(cx))
795 })
796 .unwrap_or_default()
797 }
798
799 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
800 if matches!(
801 self.edit_prediction_model,
802 EditPredictionModel::Zeta2 { .. }
803 ) {
804 self.user_store.read(cx).edit_prediction_usage()
805 } else {
806 None
807 }
808 }
809
810 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
811 self.get_or_init_project(project, cx);
812 }
813
814 pub fn register_buffer(
815 &mut self,
816 buffer: &Entity<Buffer>,
817 project: &Entity<Project>,
818 cx: &mut Context<Self>,
819 ) {
820 let project_state = self.get_or_init_project(project, cx);
821 Self::register_buffer_impl(project_state, buffer, project, cx);
822 }
823
    /// Returns the state for `project`, creating it on first access.
    ///
    /// Creating the state also creates the project's `RelatedExcerptStore`
    /// (whose events are routed to `handle_excerpt_store_event`), subscribes to
    /// project events, and removes the state when the project is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            // Field initializers run in the order written; `context` creates an
            // entity and a subscription before `_subscriptions` below.
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Forget the project's state once the project entity is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
863
864 pub fn remove_project(&mut self, project: &Entity<Project>) {
865 self.projects.remove(&project.entity_id());
866 }
867
868 fn handle_excerpt_store_event(
869 &mut self,
870 project_entity_id: EntityId,
871 event: &RelatedExcerptStoreEvent,
872 ) {
873 if let Some(project_state) = self.projects.get(&project_entity_id) {
874 if let Some(debug_tx) = project_state.debug_tx.clone() {
875 match event {
876 RelatedExcerptStoreEvent::StartedRefresh => {
877 debug_tx
878 .unbounded_send(DebugEvent::ContextRetrievalStarted(
879 ContextRetrievalStartedDebugEvent {
880 project_entity_id: project_entity_id,
881 timestamp: Instant::now(),
882 search_prompt: String::new(),
883 },
884 ))
885 .ok();
886 }
887 RelatedExcerptStoreEvent::FinishedRefresh {
888 cache_hit_count,
889 cache_miss_count,
890 mean_definition_latency,
891 max_definition_latency,
892 } => {
893 debug_tx
894 .unbounded_send(DebugEvent::ContextRetrievalFinished(
895 ContextRetrievalFinishedDebugEvent {
896 project_entity_id: project_entity_id,
897 timestamp: Instant::now(),
898 metadata: vec![
899 (
900 "Cache Hits",
901 format!(
902 "{}/{}",
903 cache_hit_count,
904 cache_hit_count + cache_miss_count
905 )
906 .into(),
907 ),
908 (
909 "Max LSP Time",
910 format!("{} ms", max_definition_latency.as_millis())
911 .into(),
912 ),
913 (
914 "Mean LSP Time",
915 format!("{} ms", mean_definition_latency.as_millis())
916 .into(),
917 ),
918 ],
919 },
920 ))
921 .ok();
922 }
923 }
924 }
925 }
926 }
927
928 pub fn debug_info(
929 &mut self,
930 project: &Entity<Project>,
931 cx: &mut Context<Self>,
932 ) -> mpsc::UnboundedReceiver<DebugEvent> {
933 let project_state = self.get_or_init_project(project, cx);
934 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
935 project_state.debug_tx = Some(debug_watch_tx);
936 debug_watch_rx
937 }
938
939 fn handle_project_event(
940 &mut self,
941 project: Entity<Project>,
942 event: &project::Event,
943 cx: &mut Context<Self>,
944 ) {
945 // TODO [zeta2] init with recent paths
946 match event {
947 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
948 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
949 return;
950 };
951 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
952 if let Some(path) = path {
953 if let Some(ix) = project_state
954 .recent_paths
955 .iter()
956 .position(|probe| probe == &path)
957 {
958 project_state.recent_paths.remove(ix);
959 }
960 project_state.recent_paths.push_front(path);
961 }
962 }
963 project::Event::DiagnosticsUpdated { .. } => {
964 if cx.has_flag::<Zeta2FeatureFlag>() {
965 self.refresh_prediction_from_diagnostics(project, cx);
966 }
967 }
968 _ => (),
969 }
970 }
971
    /// Ensures `buffer` is registered in `project_state` and returns its entry.
    ///
    /// If the buffer has a file in one of the project's worktrees, a license
    /// watcher is created for that worktree the first time it is seen, and
    /// removed when the worktree is released. A newly registered buffer starts
    /// from its current snapshot and file. Its subscriptions report `Edited`
    /// events via `report_changes_for_buffer` and drop the registration when
    /// the buffer is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Weak project handle: edits after the project is gone are ignored.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1037
/// Records that `buffer` (belonging to `project`) changed.
///
/// Registers the buffer if needed and swaps in its latest file and text snapshot. It then
/// records a `UserActionRecord` classifying the edits since the previously stored snapshot.
/// Finally it either coalesces this change into the project's `last_event` or finalizes
/// that event into the `events` history and starts a new `last_event`.
///
/// Returns early without recording anything if the buffer's version is unchanged.
fn report_changes_for_buffer(
    &mut self,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_project(project, cx);
    let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

    let buf = buffer.read(cx);
    let new_file = buf.file().cloned();
    let new_snapshot = buf.text_snapshot();
    if new_snapshot.version == registered_buffer.snapshot.version {
        return;
    }

    // Store the latest file/snapshot, keeping the previous ones to diff against.
    let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
    let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
    let mut num_edits = 0usize;
    let mut total_deleted = 0usize;
    let mut total_inserted = 0usize;
    // Spans from the first edit's start anchor to the last edit's end anchor.
    let mut edit_range: Option<Range<Anchor>> = None;
    let mut last_offset: Option<usize> = None;

    for (edit, anchor_range) in
        new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
    {
        num_edits += 1;
        total_deleted += edit.old.len();
        total_inserted += edit.new.len();
        edit_range = Some(match edit_range {
            None => anchor_range,
            Some(acc) => acc.start..anchor_range.end,
        });
        last_offset = Some(edit.new.end);
    }

    if num_edits > 0 {
        // Classify the batch: exactly one unit inserted (or deleted) per edit counts as a
        // "char" action; any other amount counts as a "selection" action.
        // NOTE(review): the lengths come from `usize` offset ranges (presumably byte
        // offsets), so typing or deleting one multi-byte character is classified as a
        // selection. Confirm this is intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock milliseconds since the Unix epoch; 0 if the clock reads earlier.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }
    }

    let events = &mut project_state.events;

    let now = cx.background_executor().now();
    if let Some(last_event) = project_state.last_event.as_mut() {
        // This change continues the last event only if it starts exactly where the last
        // event's snapshot left off, in the same buffer.
        let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
            == last_event.new_snapshot.remote_id()
            && old_snapshot.version == last_event.new_snapshot.version;

        // ...and if the two edit ranges overlap or are within CHANGE_GROUPING_LINE_SPAN rows.
        let should_coalesce = is_next_snapshot_of_same_buffer
            && edit_range
                .as_ref()
                .zip(last_event.edit_range.as_ref())
                .is_some_and(|(a, b)| {
                    let a = a.to_point(&new_snapshot);
                    let b = b.to_point(&new_snapshot);
                    if a.start > b.end {
                        a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                    } else if b.start > a.end {
                        b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                    } else {
                        true
                    }
                });

        if should_coalesce {
            // If at least LAST_CHANGE_GROUPING_TIME passed since the previous edit, remember
            // the snapshot as it was before this edit as the latest pause boundary.
            let pause_elapsed = last_event
                .last_edit_time
                .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                .unwrap_or(false);
            if pause_elapsed {
                last_event.snapshot_after_last_editing_pause =
                    Some(last_event.new_snapshot.clone());
            }

            last_event.edit_range = edit_range;
            last_event.new_snapshot = new_snapshot;
            last_event.last_edit_time = Some(now);
            return;
        }
    }

    // Not coalesced: finalize the previous event into the history (if it survives
    // `finalize`) before starting a new one.
    if let Some(event) = project_state.last_event.take() {
        if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
            // Keeps `events` at no more than EVENT_COUNT_MAX - 1 entries after the push
            // (presumably leaving room for the in-progress `last_event`).
            if events.len() + 1 >= EVENT_COUNT_MAX {
                events.pop_front();
            }
            events.push_back(event);
        }
    }

    project_state.last_event = Some(LastEvent {
        old_file,
        new_file,
        old_snapshot,
        new_snapshot,
        edit_range,
        snapshot_after_last_editing_pause: None,
        last_edit_time: Some(now),
    });
}
1161
1162 fn prediction_at(
1163 &mut self,
1164 buffer: &Entity<Buffer>,
1165 position: Option<language::Anchor>,
1166 project: &Entity<Project>,
1167 cx: &App,
1168 ) -> Option<BufferEditPrediction<'_>> {
1169 let project_state = self.projects.get_mut(&project.entity_id())?;
1170 if let Some(position) = position
1171 && let Some(buffer) = project_state
1172 .registered_buffers
1173 .get_mut(&buffer.entity_id())
1174 {
1175 buffer.last_position = Some(position);
1176 }
1177
1178 let CurrentEditPrediction {
1179 requested_by,
1180 prediction,
1181 ..
1182 } = project_state.current_prediction.as_ref()?;
1183
1184 if prediction.targets_buffer(buffer.read(cx)) {
1185 Some(BufferEditPrediction::Local { prediction })
1186 } else {
1187 let show_jump = match requested_by {
1188 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1189 requested_by_buffer_id == &buffer.entity_id()
1190 }
1191 PredictionRequestedBy::DiagnosticsUpdate => true,
1192 };
1193
1194 if show_jump {
1195 Some(BufferEditPrediction::Jump { prediction })
1196 } else {
1197 None
1198 }
1199 }
1200 }
1201
1202 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1203 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1204 return;
1205 };
1206
1207 let Some(current_prediction) = project_state.current_prediction.take() else {
1208 return;
1209 };
1210
1211 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1212 project_state.cancel_pending_prediction(pending_prediction, cx);
1213 }
1214
1215 match self.edit_prediction_model {
1216 EditPredictionModel::Sweep => {
1217 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1218 }
1219 EditPredictionModel::Mercury => {
1220 mercury::edit_prediction_accepted(
1221 current_prediction.prediction.id,
1222 self.client.http_client(),
1223 cx,
1224 );
1225 }
1226 EditPredictionModel::Ollama => {}
1227 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1228 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1229 }
1230 }
1231 }
1232
1233 async fn handle_rejected_predictions(
1234 rx: UnboundedReceiver<EditPredictionRejection>,
1235 client: Arc<Client>,
1236 llm_token: LlmApiToken,
1237 app_version: Version,
1238 background_executor: BackgroundExecutor,
1239 ) {
1240 let mut rx = std::pin::pin!(rx.peekable());
1241 let mut batched = Vec::new();
1242
1243 while let Some(rejection) = rx.next().await {
1244 batched.push(rejection);
1245
1246 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1247 select_biased! {
1248 next = rx.as_mut().peek().fuse() => {
1249 if next.is_some() {
1250 continue;
1251 }
1252 }
1253 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1254 }
1255 }
1256
1257 let url = client
1258 .http_client()
1259 .build_zed_llm_url("/predict_edits/reject", &[])
1260 .unwrap();
1261
1262 let flush_count = batched
1263 .len()
1264 // in case items have accumulated after failure
1265 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1266 let start = batched.len() - flush_count;
1267
1268 let body = RejectEditPredictionsBodyRef {
1269 rejections: &batched[start..],
1270 };
1271
1272 let result = Self::send_api_request::<()>(
1273 |builder| {
1274 let req = builder
1275 .uri(url.as_ref())
1276 .body(serde_json::to_string(&body)?.into());
1277 anyhow::Ok(req?)
1278 },
1279 client.clone(),
1280 llm_token.clone(),
1281 app_version.clone(),
1282 true,
1283 )
1284 .await;
1285
1286 if result.log_err().is_some() {
1287 batched.drain(start..);
1288 }
1289 }
1290 }
1291
1292 fn reject_current_prediction(
1293 &mut self,
1294 reason: EditPredictionRejectReason,
1295 project: &Entity<Project>,
1296 dismiss_reason: edit_prediction_types::EditPredictionDismissReason,
1297 cx: &App,
1298 ) {
1299 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1300 project_state.pending_predictions.clear();
1301 if let Some(prediction) = project_state.current_prediction.take() {
1302 self.reject_prediction(
1303 prediction.prediction.id,
1304 reason,
1305 prediction.was_shown,
1306 dismiss_reason,
1307 cx,
1308 );
1309 }
1310 };
1311 }
1312
1313 fn did_show_current_prediction(
1314 &mut self,
1315 project: &Entity<Project>,
1316 display_type: edit_prediction_types::SuggestionDisplayType,
1317 cx: &mut Context<Self>,
1318 ) {
1319 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1320 return;
1321 };
1322
1323 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1324 return;
1325 };
1326
1327 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1328 let previous_shown_with = current_prediction.shown_with;
1329
1330 if previous_shown_with.is_none() || !is_jump {
1331 current_prediction.shown_with = Some(display_type);
1332 }
1333
1334 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1335
1336 if is_first_non_jump_show {
1337 current_prediction.was_shown = true;
1338 }
1339
1340 let display_type_changed = previous_shown_with != Some(display_type);
1341
1342 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1343 sweep_ai::edit_prediction_shown(
1344 &self.sweep_ai,
1345 self.client.clone(),
1346 ¤t_prediction.prediction,
1347 display_type,
1348 cx,
1349 );
1350 }
1351
1352 if is_first_non_jump_show {
1353 self.shown_predictions
1354 .push_front(current_prediction.prediction.clone());
1355 if self.shown_predictions.len() > 50 {
1356 let completion = self.shown_predictions.pop_back().unwrap();
1357 self.rated_predictions.remove(&completion.id);
1358 }
1359 }
1360 }
1361
1362 fn reject_prediction(
1363 &mut self,
1364 prediction_id: EditPredictionId,
1365 reason: EditPredictionRejectReason,
1366 was_shown: bool,
1367 dismiss_reason: edit_prediction_types::EditPredictionDismissReason,
1368 cx: &App,
1369 ) {
1370 match self.edit_prediction_model {
1371 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
1372 if self.custom_predict_edits_url.is_none() {
1373 self.reject_predictions_tx
1374 .unbounded_send(EditPredictionRejection {
1375 request_id: prediction_id.to_string(),
1376 reason,
1377 was_shown,
1378 })
1379 .log_err();
1380 }
1381 }
1382 EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1383 EditPredictionModel::Mercury => {
1384 mercury::edit_prediction_rejected(
1385 prediction_id,
1386 was_shown,
1387 dismiss_reason,
1388 self.client.http_client(),
1389 cx,
1390 );
1391 }
1392 }
1393 }
1394
1395 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1396 self.projects
1397 .get(&project.entity_id())
1398 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1399 }
1400
1401 pub fn refresh_prediction_from_buffer(
1402 &mut self,
1403 project: Entity<Project>,
1404 buffer: Entity<Buffer>,
1405 position: language::Anchor,
1406 cx: &mut Context<Self>,
1407 ) {
1408 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1409 let Some(request_task) = this
1410 .update(cx, |this, cx| {
1411 this.request_prediction(
1412 &project,
1413 &buffer,
1414 position,
1415 PredictEditsRequestTrigger::Other,
1416 cx,
1417 )
1418 })
1419 .log_err()
1420 else {
1421 return Task::ready(anyhow::Ok(None));
1422 };
1423
1424 cx.spawn(async move |_cx| {
1425 request_task.await.map(|prediction_result| {
1426 prediction_result.map(|prediction_result| {
1427 (
1428 prediction_result,
1429 PredictionRequestedBy::Buffer(buffer.entity_id()),
1430 )
1431 })
1432 })
1433 })
1434 })
1435 }
1436
/// Queues a prediction refresh for `project` that is driven by a diagnostics update.
///
/// Does nothing if the project is unknown or already has a current prediction. Otherwise,
/// once the throttle allows it, the refresh:
/// 1. finds the project's active buffer and cursor;
/// 2. checks that predictions are enabled there;
/// 3. finds the nearest diagnostic location (in the active buffer or, failing that, in
///    another file);
/// 4. requests a prediction at that location with the `Diagnostics` trigger.
///
/// If a current prediction appeared while the request was running, the result is turned
/// into a `CurrentPreferred` rejection instead.
pub fn refresh_prediction_from_diagnostics(
    &mut self,
    project: Entity<Project>,
    cx: &mut Context<Self>,
) {
    let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
        return;
    };

    // Prefer predictions from buffer
    if project_state.current_prediction.is_some() {
        return;
    };

    // Throttled per project, unlike buffer-driven refreshes, which throttle per buffer.
    self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
        let Some((active_buffer, snapshot, cursor_point)) = this
            .read_with(cx, |this, cx| {
                let project_state = this.projects.get(&project.entity_id())?;
                let (buffer, position) = project_state.active_buffer(&project, cx)?;
                let snapshot = buffer.read(cx).snapshot();

                if !Self::predictions_enabled_at(&snapshot, position, cx) {
                    return None;
                }

                // Without a known cursor position, search relative to the default point.
                let cursor_point = position
                    .map(|pos| pos.to_point(&snapshot))
                    .unwrap_or_default();

                Some((buffer, snapshot, cursor_point))
            })
            .log_err()
            .flatten()
        else {
            return Task::ready(anyhow::Ok(None));
        };

        cx.spawn(async move |cx| {
            // The default (empty) search range presumably excludes no active-buffer
            // diagnostic for overlapping it. Confirm `overlaps` semantics for empty ranges.
            let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                active_buffer,
                &snapshot,
                Default::default(),
                cursor_point,
                &project,
                cx,
            )
            .await?
            else {
                return anyhow::Ok(None);
            };

            let Some(prediction_result) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &jump_buffer,
                        jump_position,
                        PredictEditsRequestTrigger::Diagnostics,
                        cx,
                    )
                })?
                .await?
            else {
                return anyhow::Ok(None);
            };

            this.update(cx, |this, cx| {
                Some((
                    // A prediction that arrived meanwhile wins over this diagnostics-driven one.
                    if this
                        .get_or_init_project(&project, cx)
                        .current_prediction
                        .is_none()
                    {
                        prediction_result
                    } else {
                        EditPredictionResult {
                            id: prediction_result.id,
                            prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                        }
                    },
                    PredictionRequestedBy::DiagnosticsUpdate,
                ))
            })
        })
    });
}
1523
1524 fn predictions_enabled_at(
1525 snapshot: &BufferSnapshot,
1526 position: Option<language::Anchor>,
1527 cx: &App,
1528 ) -> bool {
1529 let file = snapshot.file();
1530 let all_settings = all_language_settings(file, cx);
1531 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1532 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1533 {
1534 return false;
1535 }
1536
1537 if let Some(last_position) = position {
1538 let settings = snapshot.settings_at(last_position, cx);
1539
1540 if !settings.edit_predictions_disabled_in.is_empty()
1541 && let Some(scope) = snapshot.language_scope_at(last_position)
1542 && let Some(scope_name) = scope.override_name()
1543 && settings
1544 .edit_predictions_disabled_in
1545 .iter()
1546 .any(|s| s == scope_name)
1547 {
1548 return false;
1549 }
1550 }
1551
1552 true
1553 }
1554
/// Minimum delay between consecutive prediction refreshes for the same throttle entity.
///
/// Zero in test builds, which effectively disables the throttle wait.
#[cfg(test)]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(0);
/// Minimum delay between consecutive prediction refreshes for the same throttle entity.
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1559
1560 fn queue_prediction_refresh(
1561 &mut self,
1562 project: Entity<Project>,
1563 throttle_entity: EntityId,
1564 cx: &mut Context<Self>,
1565 do_refresh: impl FnOnce(
1566 WeakEntity<Self>,
1567 &mut AsyncApp,
1568 )
1569 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1570 + 'static,
1571 ) {
1572 let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1573 let drop_on_cancel = is_ollama;
1574 let max_pending_predictions = if is_ollama { 1 } else { 2 };
1575 let project_state = self.get_or_init_project(&project, cx);
1576 let pending_prediction_id = project_state.next_pending_prediction_id;
1577 project_state.next_pending_prediction_id += 1;
1578 let last_request = project_state.last_prediction_refresh;
1579
1580 let task = cx.spawn(async move |this, cx| {
1581 if let Some((last_entity, last_timestamp)) = last_request
1582 && throttle_entity == last_entity
1583 && let Some(timeout) =
1584 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1585 {
1586 cx.background_executor().timer(timeout).await;
1587 }
1588
1589 // If this task was cancelled before the throttle timeout expired,
1590 // do not perform a request.
1591 let mut is_cancelled = true;
1592 this.update(cx, |this, cx| {
1593 let project_state = this.get_or_init_project(&project, cx);
1594 if !project_state
1595 .cancelled_predictions
1596 .remove(&pending_prediction_id)
1597 {
1598 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1599 is_cancelled = false;
1600 }
1601 })
1602 .ok();
1603 if is_cancelled {
1604 return None;
1605 }
1606
1607 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1608 let new_prediction_id = new_prediction_result
1609 .as_ref()
1610 .map(|(prediction, _)| prediction.id.clone());
1611
1612 // When a prediction completes, remove it from the pending list, and cancel
1613 // any pending predictions that were enqueued before it.
1614 this.update(cx, |this, cx| {
1615 let project_state = this.get_or_init_project(&project, cx);
1616
1617 let is_cancelled = project_state
1618 .cancelled_predictions
1619 .remove(&pending_prediction_id);
1620
1621 let new_current_prediction = if !is_cancelled
1622 && let Some((prediction_result, requested_by)) = new_prediction_result
1623 {
1624 match prediction_result.prediction {
1625 Ok(prediction) => {
1626 let new_prediction = CurrentEditPrediction {
1627 requested_by,
1628 prediction,
1629 was_shown: false,
1630 shown_with: None,
1631 };
1632
1633 if let Some(current_prediction) =
1634 project_state.current_prediction.as_ref()
1635 {
1636 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1637 {
1638 this.reject_current_prediction(
1639 EditPredictionRejectReason::Replaced,
1640 &project,
1641 edit_prediction_types::EditPredictionDismissReason::Ignored,
1642 cx,
1643 );
1644
1645 Some(new_prediction)
1646 } else {
1647 this.reject_prediction(
1648 new_prediction.prediction.id,
1649 EditPredictionRejectReason::CurrentPreferred,
1650 false,
1651 edit_prediction_types::EditPredictionDismissReason::Ignored,
1652 cx,
1653 );
1654 None
1655 }
1656 } else {
1657 Some(new_prediction)
1658 }
1659 }
1660 Err(reject_reason) => {
1661 this.reject_prediction(
1662 prediction_result.id,
1663 reject_reason,
1664 false,
1665 edit_prediction_types::EditPredictionDismissReason::Ignored,
1666 cx,
1667 );
1668 None
1669 }
1670 }
1671 } else {
1672 None
1673 };
1674
1675 let project_state = this.get_or_init_project(&project, cx);
1676
1677 if let Some(new_prediction) = new_current_prediction {
1678 project_state.current_prediction = Some(new_prediction);
1679 }
1680
1681 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1682 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1683 if pending_prediction.id == pending_prediction_id {
1684 pending_predictions.remove(ix);
1685 for pending_prediction in pending_predictions.drain(0..ix) {
1686 project_state.cancel_pending_prediction(pending_prediction, cx)
1687 }
1688 break;
1689 }
1690 }
1691 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1692 cx.notify();
1693 })
1694 .ok();
1695
1696 new_prediction_id
1697 });
1698
1699 if project_state.pending_predictions.len() < max_pending_predictions {
1700 project_state.pending_predictions.push(PendingPrediction {
1701 id: pending_prediction_id,
1702 task,
1703 drop_on_cancel,
1704 });
1705 } else {
1706 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1707 project_state.pending_predictions.push(PendingPrediction {
1708 id: pending_prediction_id,
1709 task,
1710 drop_on_cancel,
1711 });
1712 project_state.cancel_pending_prediction(pending_prediction, cx);
1713 }
1714 }
1715
1716 pub fn request_prediction(
1717 &mut self,
1718 project: &Entity<Project>,
1719 active_buffer: &Entity<Buffer>,
1720 position: language::Anchor,
1721 trigger: PredictEditsRequestTrigger,
1722 cx: &mut Context<Self>,
1723 ) -> Task<Result<Option<EditPredictionResult>>> {
1724 self.request_prediction_internal(
1725 project.clone(),
1726 active_buffer.clone(),
1727 position,
1728 trigger,
1729 cx.has_flag::<Zeta2FeatureFlag>(),
1730 cx,
1731 )
1732 }
1733
/// Builds the model input for a prediction at `position` in `active_buffer` and dispatches
/// it to the configured `edit_prediction_model`.
///
/// Input gathered:
/// - the project's edit history, split at editing pauses;
/// - recorded user actions, plus a synthetic `CursorMovement` when the cursor moved
///   within the same buffer since the last recorded action;
/// - related files.
///
/// Example capture: when data collection is permitted for the file, the events and the
/// related files, and sampling selects this request, an example is captured and sent as
/// a telemetry event.
///
/// Jump fallback: if the model returns no prediction, `allow_jump` is true and there is
/// edit history, the nearest diagnostic outside the cursor's search window (or in another
/// file) is looked up. One follow-up request is made there with `allow_jump = false`, so
/// at most one jump occurs.
fn request_prediction_internal(
    &mut self,
    project: Entity<Project>,
    active_buffer: Entity<Buffer>,
    position: language::Anchor,
    trigger: PredictEditsRequestTrigger,
    allow_jump: bool,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Rows above and below the cursor forming `diagnostic_search_range`. Jump targets are
    // searched outside this window.
    const DIAGNOSTIC_LINES_RANGE: u32 = 20;

    self.get_or_init_project(&project, cx);
    let project_state = self.projects.get(&project.entity_id()).unwrap();
    let stored_events = project_state.events_split_by_pause(cx);
    let has_events = !stored_events.is_empty();
    let events: Vec<Arc<zeta_prompt::Event>> =
        stored_events.into_iter().map(|e| e.event).collect();
    let debug_tx = project_state.debug_tx.clone();

    let snapshot = active_buffer.read(cx).snapshot();
    let cursor_point = position.to_point(&snapshot);
    let current_offset = position.to_offset(&snapshot);

    let mut user_actions: Vec<UserActionRecord> =
        project_state.user_actions.iter().cloned().collect();

    // Record a cursor move if the cursor left the offset of the last action in this buffer.
    if let Some(last_action) = user_actions.last() {
        if last_action.buffer_id == active_buffer.entity_id()
            && current_offset != last_action.offset
        {
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            user_actions.push(UserActionRecord {
                action_type: UserActionType::CursorMovement,
                buffer_id: active_buffer.entity_id(),
                line_number: cursor_point.row,
                offset: current_offset,
                timestamp_epoch_ms,
            });
        }
    }
    let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
    let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
    let diagnostic_search_range =
        Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

    let related_files = self.context_for_project(&project, cx);

    let inputs = EditPredictionModelInput {
        project: project.clone(),
        buffer: active_buffer.clone(),
        snapshot: snapshot.clone(),
        position,
        events,
        related_files,
        recent_paths: project_state.recent_paths.clone(),
        trigger,
        diagnostic_search_range: diagnostic_search_range.clone(),
        debug_tx,
        user_actions,
    };

    let can_collect_example = snapshot
        .file()
        .is_some_and(|file| self.can_collect_file(&project, file, cx))
        && self.can_collect_events(&inputs.events, cx)
        && self.can_collect_related_files(&project, cx);

    if can_collect_example && should_sample_edit_prediction_example_capture(cx) {
        let events_for_capture =
            self.edit_history_for_project_with_pause_split_last_event(&project, cx);
        let related_files_for_capture = inputs.related_files.clone();
        if let Some(example_task) = capture_example::capture_example(
            project.clone(),
            active_buffer.clone(),
            position,
            events_for_capture,
            related_files_for_capture,
            false,
            cx,
        ) {
            cx.spawn(async move |_this, _cx| {
                let example = example_task.await?;
                telemetry::event!("Edit Prediction Example Captured", example = example);
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }
    let task = match self.edit_prediction_model {
        EditPredictionModel::Zeta1 => {
            // Optionally fire an extra Zeta2 request with the `Testing` trigger. It is
            // detached, so its result is not used here.
            if should_send_testing_zeta2_request() {
                let mut zeta2_inputs = inputs.clone();
                zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
                zeta2::request_prediction_with_zeta2(
                    self,
                    zeta2_inputs,
                    Default::default(),
                    cx,
                )
                .detach();
            }
            zeta1::request_prediction_with_zeta1(self, inputs, cx)
        }
        EditPredictionModel::Zeta2 { version } => {
            zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
        }
        EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
        EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
    };

    cx.spawn(async move |this, cx| {
        let prediction = task.await?;

        if prediction.is_none() && allow_jump {
            // Same value as the `cursor_point` computed above, recomputed from the snapshot.
            let cursor_point = position.to_point(&snapshot);
            if has_events
                && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer.clone(),
                    &snapshot,
                    diagnostic_search_range,
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
            {
                // Retry once at the diagnostic; `allow_jump = false` prevents chaining jumps.
                return this
                    .update(cx, |this, cx| {
                        this.request_prediction_internal(
                            project,
                            jump_buffer,
                            jump_position,
                            trigger,
                            false,
                            cx,
                        )
                    })?
                    .await;
            }

            return anyhow::Ok(None);
        }

        Ok(prediction)
    })
}
1884
/// Finds a diagnostic location to jump to for a follow-up prediction.
///
/// 1. In the active buffer, consider the primary entry of each diagnostic group whose
///    range does not overlap `active_buffer_diagnostic_search_range`. Pick the one whose
///    start row is closest to `active_buffer_cursor_point`.
/// 2. If there is none, pick another project path with diagnostic summaries that shares
///    the most leading path components with the active buffer's path. Open it and pick
///    the diagnostic with the lowest `severity` value (presumably the most severe, given
///    LSP's ERROR = 1 ordering; confirm).
///
/// Returns the chosen buffer and an anchor at the diagnostic's start, or `None` if nothing
/// was found.
///
/// # Errors
///
/// Propagates a failure to open the candidate buffer.
async fn next_diagnostic_location(
    active_buffer: Entity<Buffer>,
    active_buffer_snapshot: &BufferSnapshot,
    active_buffer_diagnostic_search_range: Range<Point>,
    active_buffer_cursor_point: Point,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
    // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
    let mut jump_location = active_buffer_snapshot
        .diagnostic_groups(None)
        .into_iter()
        .filter_map(|(_, group)| {
            let range = &group.entries[group.primary_ix]
                .range
                .to_point(&active_buffer_snapshot);
            if range.overlaps(&active_buffer_diagnostic_search_range) {
                None
            } else {
                Some(range.start)
            }
        })
        .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
        .map(|position| {
            (
                active_buffer.clone(),
                active_buffer_snapshot.anchor_before(position),
            )
        });

    if jump_location.is_none() {
        // `None` when the active buffer has no file.
        let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
            let file = buffer.file()?;

            Some(ProjectPath {
                worktree_id: file.worktree_id(cx),
                path: file.path().clone(),
            })
        });

        let buffer_task = project.update(cx, |project, cx| {
            let (path, _, _) = project
                .diagnostic_summaries(false, cx)
                .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                .max_by_key(|(path, _, _)| {
                    // find the buffer with errors that shares most parent directories
                    path.path
                        .components()
                        .zip(
                            active_buffer_path
                                .as_ref()
                                .map(|p| p.path.components())
                                .unwrap_or_default(),
                        )
                        .take_while(|(a, b)| a == b)
                        .count()
                })?;

            Some(project.open_buffer(path, cx))
        });

        if let Some(buffer_task) = buffer_task {
            let closest_buffer = buffer_task.await?;

            jump_location = closest_buffer
                .read_with(cx, |buffer, _cx| {
                    buffer
                        .buffer_diagnostics(None)
                        .into_iter()
                        .min_by_key(|entry| entry.diagnostic.severity)
                        .map(|entry| entry.range.start)
                })
                .map(|position| (closest_buffer, position));
        }
    }

    anyhow::Ok(jump_location)
}
1963
1964 async fn send_raw_llm_request(
1965 request: RawCompletionRequest,
1966 client: Arc<Client>,
1967 custom_url: Option<Arc<Url>>,
1968 llm_token: LlmApiToken,
1969 app_version: Version,
1970 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1971 let url = if let Some(custom_url) = custom_url {
1972 custom_url.as_ref().clone()
1973 } else {
1974 client
1975 .http_client()
1976 .build_zed_llm_url("/predict_edits/raw", &[])?
1977 };
1978
1979 Self::send_api_request(
1980 |builder| {
1981 let req = builder
1982 .uri(url.as_ref())
1983 .body(serde_json::to_string(&request)?.into());
1984 Ok(req?)
1985 },
1986 client,
1987 llm_token,
1988 app_version,
1989 true,
1990 )
1991 .await
1992 }
1993
1994 pub(crate) async fn send_v3_request(
1995 input: ZetaPromptInput,
1996 prompt_version: ZetaVersion,
1997 client: Arc<Client>,
1998 llm_token: LlmApiToken,
1999 app_version: Version,
2000 trigger: PredictEditsRequestTrigger,
2001 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2002 let url = client
2003 .http_client()
2004 .build_zed_llm_url("/predict_edits/v3", &[])?;
2005
2006 let request = PredictEditsV3Request {
2007 input,
2008 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
2009 prompt_version,
2010 trigger,
2011 };
2012
2013 Self::send_api_request(
2014 |builder| {
2015 let req = builder
2016 .uri(url.as_ref())
2017 .body(serde_json::to_string(&request)?.into());
2018 Ok(req?)
2019 },
2020 client,
2021 llm_token,
2022 app_version,
2023 true,
2024 )
2025 .await
2026 }
2027
2028 fn handle_api_response<T>(
2029 this: &WeakEntity<Self>,
2030 response: Result<(T, Option<EditPredictionUsage>)>,
2031 cx: &mut gpui::AsyncApp,
2032 ) -> Result<T> {
2033 match response {
2034 Ok((data, usage)) => {
2035 if let Some(usage) = usage {
2036 this.update(cx, |this, cx| {
2037 this.user_store.update(cx, |user_store, cx| {
2038 user_store.update_edit_prediction_usage(usage, cx);
2039 });
2040 })
2041 .ok();
2042 }
2043 Ok(data)
2044 }
2045 Err(err) => {
2046 if err.is::<ZedUpdateRequiredError>() {
2047 cx.update(|cx| {
2048 this.update(cx, |this, _cx| {
2049 this.update_required = true;
2050 })
2051 .ok();
2052
2053 let error_message: SharedString = err.to_string().into();
2054 show_app_notification(
2055 NotificationId::unique::<ZedUpdateRequiredError>(),
2056 cx,
2057 move |cx| {
2058 cx.new(|cx| {
2059 ErrorMessagePrompt::new(error_message.clone(), cx)
2060 .with_link_button("Update Zed", "https://zed.dev/releases")
2061 })
2062 },
2063 );
2064 });
2065 }
2066 Err(err)
2067 }
2068 }
2069 }
2070
/// POSTs a JSON request produced by `build` and deserializes a successful JSON response
/// body into `Res`. Returns it together with any `EditPredictionUsage` parsed from the
/// response headers.
///
/// Authentication:
/// - If the `ZED_PREDICT_EDITS_TOKEN` environment variable is set, its value is the bearer
///   token.
/// - Otherwise an LLM token is acquired. When `require_auth` is true, failing to acquire
///   it is an error; when false, the request is sent without an `Authorization` header.
///
/// `build` may be called more than once: if the response says the LLM token needs
/// refreshing (and a token was sent), the token is refreshed and the request is retried
/// exactly once.
///
/// # Errors
///
/// - `ZedUpdateRequiredError` when the response's minimum-required-version header is
///   newer than `app_version`. This is checked before the status.
/// - A generic error with the status and body for any other non-success response.
/// - Transport, token-acquisition and JSON (de)serialization errors.
async fn send_api_request<Res>(
    build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
    client: Arc<Client>,
    llm_token: LlmApiToken,
    app_version: Version,
    require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
    Res: DeserializeOwned,
{
    let http_client = client.http_client();

    let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
        Some(custom_token)
    } else if require_auth {
        Some(llm_token.acquire(&client).await?)
    } else {
        llm_token.acquire(&client).await.ok()
    };
    let mut did_retry = false;

    loop {
        let request_builder = http_client::Request::builder().method(Method::POST);

        let mut request_builder = request_builder
            .header("Content-Type", "application/json")
            .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

        // Only add Authorization header if we have a token
        if let Some(ref token_value) = token {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", token_value));
        }

        let request = build(request_builder)?;

        let mut response = http_client.send(request).await?;

        // The server may require a newer client; this applies to any status.
        if let Some(minimum_required_version) = response
            .headers()
            .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
            .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
        {
            anyhow::ensure!(
                app_version >= minimum_required_version,
                ZedUpdateRequiredError {
                    minimum_version: minimum_required_version
                }
            );
        }

        if response.status().is_success() {
            let usage = EditPredictionUsage::from_headers(response.headers()).ok();

            let mut body = Vec::new();
            response.body_mut().read_to_end(&mut body).await?;
            return Ok((serde_json::from_slice(&body)?, usage));
        } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
            // Single retry with a refreshed token.
            did_retry = true;
            token = Some(llm_token.refresh(&client).await?);
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "Request failed with status: {:?}\nBody: {}",
                response.status(),
                body
            );
        }
    }
}
2142
2143 pub fn refresh_context(
2144 &mut self,
2145 project: &Entity<Project>,
2146 buffer: &Entity<language::Buffer>,
2147 cursor_position: language::Anchor,
2148 cx: &mut Context<Self>,
2149 ) {
2150 self.get_or_init_project(project, cx)
2151 .context
2152 .update(cx, |store, cx| {
2153 store.refresh(buffer.clone(), cursor_position, cx);
2154 });
2155 }
2156
/// Overrides the related files held by the `context` store of `project`.
///
/// Only available with the `cli-support` feature. No buffer is involved; the files are set
/// on the project's context store as a whole.
#[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
    &mut self,
    project: &Entity<Project>,
    related_files: Vec<RelatedFile>,
    cx: &mut Context<Self>,
) {
    let context_store = &self.get_or_init_project(project, cx).context;
    context_store.update(cx, move |store, cx| {
        store.set_related_files(related_files, cx);
    });
}
2170
/// Replaces the recent paths recorded for `project` with `paths`.
///
/// Only available with the `cli-support` feature.
#[cfg(feature = "cli-support")]
pub fn set_recent_paths_for_project(
    &mut self,
    project: &Entity<Project>,
    paths: impl IntoIterator<Item = project::ProjectPath>,
    cx: &mut Context<Self>,
) {
    let recent_paths = &mut self.get_or_init_project(project, cx).recent_paths;
    *recent_paths = paths.into_iter().collect();
}
2181
2182 fn is_file_open_source(
2183 &self,
2184 project: &Entity<Project>,
2185 file: &Arc<dyn File>,
2186 cx: &App,
2187 ) -> bool {
2188 if !file.is_local() || file.is_private() {
2189 return false;
2190 }
2191 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2192 return false;
2193 };
2194 project_state
2195 .license_detection_watchers
2196 .get(&file.worktree_id(cx))
2197 .as_ref()
2198 .is_some_and(|watcher| watcher.is_project_open_source())
2199 }
2200
2201 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2202 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2203 }
2204
2205 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2206 if !self.data_collection_choice.is_enabled(cx) {
2207 return false;
2208 }
2209 events.iter().all(|event| {
2210 matches!(
2211 event.as_ref(),
2212 zeta_prompt::Event::BufferChange {
2213 in_open_source_repo: true,
2214 ..
2215 }
2216 )
2217 })
2218 }
2219
2220 fn can_collect_related_files(&self, project: &Entity<Project>, cx: &mut App) -> bool {
2221 if !self.data_collection_choice.is_enabled(cx) {
2222 return false;
2223 }
2224
2225 let related_with_buffers = self.context_for_project_with_buffers(project, cx);
2226
2227 related_with_buffers.iter().all(|(_, buffer)| {
2228 buffer
2229 .read(cx)
2230 .file()
2231 .is_some_and(|file| self.is_file_open_source(project, &file, cx))
2232 })
2233 }
2234
2235 fn load_data_collection_choice() -> DataCollectionChoice {
2236 let choice = KEY_VALUE_STORE
2237 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2238 .log_err()
2239 .flatten();
2240
2241 match choice.as_deref() {
2242 Some("true") => DataCollectionChoice::Enabled,
2243 Some("false") => DataCollectionChoice::Disabled,
2244 Some(_) => {
2245 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2246 DataCollectionChoice::NotAnswered
2247 }
2248 None => DataCollectionChoice::NotAnswered,
2249 }
2250 }
2251
2252 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2253 self.data_collection_choice = self.data_collection_choice.toggle();
2254 let new_choice = self.data_collection_choice;
2255 let is_enabled = new_choice.is_enabled(cx);
2256 db::write_and_log(cx, move || {
2257 KEY_VALUE_STORE.write_kvp(
2258 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2259 is_enabled.to_string(),
2260 )
2261 });
2262 }
2263
2264 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2265 self.shown_predictions.iter()
2266 }
2267
2268 pub fn shown_completions_len(&self) -> usize {
2269 self.shown_predictions.len()
2270 }
2271
2272 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2273 self.rated_predictions.contains(id)
2274 }
2275
2276 pub fn rate_prediction(
2277 &mut self,
2278 prediction: &EditPrediction,
2279 rating: EditPredictionRating,
2280 feedback: String,
2281 cx: &mut Context<Self>,
2282 ) {
2283 self.rated_predictions.insert(prediction.id.clone());
2284 telemetry::event!(
2285 "Edit Prediction Rated",
2286 rating,
2287 inputs = prediction.inputs,
2288 output = prediction
2289 .edit_preview
2290 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2291 feedback
2292 );
2293 self.client.telemetry().flush_events().detach();
2294 cx.notify();
2295 }
2296}
2297
2298pub(crate) fn filter_redundant_excerpts(
2299 mut related_files: Vec<RelatedFile>,
2300 cursor_path: &Path,
2301 cursor_row_range: Range<u32>,
2302) -> Vec<RelatedFile> {
2303 for file in &mut related_files {
2304 if file.path.as_ref() == cursor_path {
2305 file.excerpts.retain(|excerpt| {
2306 excerpt.row_range.start < cursor_row_range.start
2307 || excerpt.row_range.end > cursor_row_range.end
2308 });
2309 }
2310 }
2311 related_files.retain(|file| !file.excerpts.is_empty());
2312 related_files
2313}
2314
2315#[derive(Error, Debug)]
2316#[error(
2317 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2318)]
2319pub struct ZedUpdateRequiredError {
2320 minimum_version: Version,
2321}
2322
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No answer has been recorded yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2329
2330impl DataCollectionChoice {
2331 pub fn is_enabled(self, cx: &App) -> bool {
2332 if cx.is_staff() {
2333 return true;
2334 }
2335 match self {
2336 Self::Enabled => true,
2337 Self::NotAnswered | Self::Disabled => false,
2338 }
2339 }
2340
2341 #[must_use]
2342 pub fn toggle(&self) -> DataCollectionChoice {
2343 match self {
2344 Self::Enabled => Self::Disabled,
2345 Self::Disabled => Self::Enabled,
2346 Self::NotAnswered => Self::Enabled,
2347 }
2348 }
2349}
2350
2351impl From<bool> for DataCollectionChoice {
2352 fn from(value: bool) -> Self {
2353 match value {
2354 true => DataCollectionChoice::Enabled,
2355 false => DataCollectionChoice::Disabled,
2356 }
2357 }
2358}
2359
2360struct ZedPredictUpsell;
2361
2362impl Dismissable for ZedPredictUpsell {
2363 const KEY: &'static str = "dismissed-edit-predict-upsell";
2364
2365 fn dismissed() -> bool {
2366 // To make this backwards compatible with older versions of Zed, we
2367 // check if the user has seen the previous Edit Prediction Onboarding
2368 // before, by checking the data collection choice which was written to
2369 // the database once the user clicked on "Accept and Enable"
2370 if KEY_VALUE_STORE
2371 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2372 .log_err()
2373 .is_some_and(|s| s.is_some())
2374 {
2375 return true;
2376 }
2377
2378 KEY_VALUE_STORE
2379 .read_kvp(Self::KEY)
2380 .log_err()
2381 .is_some_and(|s| s.is_some())
2382 }
2383}
2384
2385pub fn should_show_upsell_modal() -> bool {
2386 !ZedPredictUpsell::dismissed()
2387}
2388
2389pub fn init(cx: &mut App) {
2390 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2391 workspace.register_action(
2392 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2393 ZedPredictModal::toggle(
2394 workspace,
2395 workspace.user_store().clone(),
2396 workspace.client().clone(),
2397 window,
2398 cx,
2399 )
2400 },
2401 );
2402
2403 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2404 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2405 settings
2406 .project
2407 .all_languages
2408 .edit_predictions
2409 .get_or_insert_default()
2410 .provider = Some(EditPredictionProvider::None)
2411 });
2412 });
2413 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2414 EditPredictionStore::try_global(cx).and_then(|store| {
2415 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2416 })
2417 }
2418
2419 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2420 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2421 copilot_ui::initiate_sign_in(copilot, window, cx);
2422 }
2423 });
2424 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2425 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2426 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2427 }
2428 });
2429 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2430 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2431 copilot_ui::initiate_sign_out(copilot, window, cx);
2432 }
2433 });
2434 })
2435 .detach();
2436}