1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use std::env;
40use text::Edit;
41use workspace::Workspace;
42use zeta_prompt::{ZetaFormat, ZetaPromptInput};
43
44use std::mem;
45use std::ops::Range;
46use std::path::Path;
47use std::rc::Rc;
48use std::str::FromStr as _;
49use std::sync::Arc;
50use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59pub mod ollama;
60mod onboarding_modal;
61pub mod open_ai_response;
62mod prediction;
63pub mod sweep_ai;
64
65pub mod udiff;
66
67mod capture_example;
68mod zed_edit_prediction_delegate;
69pub mod zeta1;
70pub mod zeta2;
71
72#[cfg(test)]
73mod edit_prediction_tests;
74
75use crate::capture_example::should_send_testing_zeta2_request;
76use crate::license_detection::LicenseDetectionWatcher;
77use crate::mercury::Mercury;
78use crate::ollama::Ollama;
79use crate::onboarding_modal::ZedPredictModal;
80pub use crate::prediction::EditPrediction;
81pub use crate::prediction::EditPredictionId;
82use crate::prediction::EditPredictionResult;
83pub use crate::sweep_ai::SweepAi;
84pub use capture_example::capture_example;
85pub use language_model::ApiKeyState;
86pub use telemetry_events::EditPredictionRating;
87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
88
// Actions in the `edit_prediction` namespace. Their handlers are registered
// outside this part of the file.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance used when deciding whether a new edit may be coalesced with the
/// previous event of the same buffer: edit ranges that overlap, or whose gap
/// spans at most this many rows, count as close enough.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// An edit arriving at least this long after the previous edit of the same
/// event is treated as following an editing pause.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Storage key for the user's data-collection choice.
/// NOTE(review): presumably a `KEY_VALUE_STORE` key read by
/// `load_data_collection_choice` (not visible here) — confirm.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reporting rejected predictions.
/// NOTE(review): consumed outside this part of the file, presumably by
/// `handle_rejected_predictions` — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
105
106pub struct Zeta2FeatureFlag;
107
108impl FeatureFlag for Zeta2FeatureFlag {
109 const NAME: &'static str = "zeta2";
110
111 fn enabled_for_staff() -> bool {
112 true
113 }
114}
115
/// gpui global holding the single app-wide [`EditPredictionStore`]; installed
/// by `EditPredictionStore::global` and read by `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
120
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// Produced from the environment by `EditPredictionStore::zeta2_raw_config_from_env`
/// (`ZED_ZETA_FORMAT`, `ZED_ZETA_MODEL`) or installed via `set_zeta2_raw_config`.
///
/// NOTE(review): an earlier comment said "the version is also used as the
/// Baseten environment name (lowercased)", but this struct has no `version`
/// field — presumably that refers to `format`. Confirm at the request site.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model id; taken from `ZED_ZETA_MODEL` when read from the environment.
    pub model_id: Option<String>,
    /// Prompt format; parsed from `ZED_ZETA_FORMAT` when read from the environment.
    pub format: ZetaFormat,
}
129
/// App-wide edit-prediction state: per-project edit history and context,
/// the selected model, provider clients, and prediction feedback bookkeeping.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    /// `false` after `new`; presumably set when the server demands a newer
    /// client version — TODO confirm.
    update_required: bool,
    /// Model used for predictions; `new` starts with `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// Raw Zeta2 endpoint configuration, if any (see [`Zeta2RawConfig`]).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    /// Sender for rejected predictions; drained by a background task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
147
/// Backend that produces edit predictions.
///
/// NOTE(review): the derived `Default` is `Zeta1`, while
/// `EditPredictionStore::new` starts with `Zeta2` — confirm the derived
/// default is still intended.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
    Ollama,
}
157
/// Inputs handed to a model for a single prediction request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` the request is based on.
    snapshot: BufferSnapshot,
    /// Position in `buffer` the prediction is requested for (presumably the cursor).
    position: Anchor,
    /// Recent edit events for the project.
    events: Vec<Arc<zeta_prompt::Event>>,
    /// Related files gathered as context.
    related_files: Vec<RelatedFile>,
    /// Recently active project paths (presumably most recent first, as kept in `ProjectState`).
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Debug listener channel, if one is attached to the project.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    /// Recent user actions for the project.
    pub user_actions: Vec<UserActionRecord>,
}
172
/// Diagnostic events streamed to the listener registered via
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// A project's related-excerpt store started a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// `handle_excerpt_store_event` sends this as an empty string.
    pub search_prompt: String,
}

/// A project's related-excerpt store finished a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display (cache hits, max and mean LSP time).
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// An edit prediction request started (emitted outside this part of the file).
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Prompt sent to the model, when available.
    pub prompt: Option<String>,
}

/// An edit prediction request finished (emitted outside this part of the file).
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    /// Raw model output, when available.
    pub model_output: Option<String>,
}
208
/// Maximum number of [`UserActionRecord`]s kept per project; older ones are evicted.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// One recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    /// Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    /// Row (`Point::row`) of `offset`.
    pub line_number: u32,
    /// For edits: offset, in the new snapshot, of the end of the last edit.
    pub offset: usize,
    /// Milliseconds since the Unix epoch; 0 if the system clock is before the epoch.
    pub timestamp_epoch_ms: u64,
}

/// Kind of a recorded user action.
///
/// For edits, `report_changes_for_buffer` classifies by total deleted/inserted
/// *byte* counts relative to the number of edits: pure insertions of one byte
/// per edit are `InsertChar`, other pure insertions `InsertSelection`; pure
/// deletions of one byte per edit are `DeleteChar`, other pure deletions
/// `DeleteSelection`; mixed edits fall back to the insertion rule.
/// NOTE(review): because counts are bytes, typing one multi-byte character is
/// classified as `InsertSelection` — confirm this is intended.
/// `CursorMovement` is not produced in this part of the file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
228
/// An event with associated metadata for reconstructing buffer state.
///
/// `old_snapshot` is the buffer as it was before the change described by `event`.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    pub old_snapshot: TextBufferSnapshot,
}
235
/// Per-project edit-prediction state owned by [`EditPredictionStore`].
struct ProjectState {
    /// Finalized edit events; `events()` returns them before `last_event`.
    /// (Presumably bounded by `EVENT_COUNT_MAX` — TODO confirm.)
    events: VecDeque<StoredEvent>,
    /// Edit event still being accumulated.
    last_event: Option<LastEvent>,
    /// Recently active paths, most recent first (updated on `ActiveEntryChanged`).
    recent_paths: VecDeque<ProjectPath>,
    /// Tracked buffers keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Debug listener installed by `debug_info`, if any.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled.
    cancelled_predictions: HashSet<usize>,
    /// Related-excerpt store providing context for this project.
    context: Entity<RelatedExcerptStore>,
    /// One license watcher per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Most recent user actions, capped at `USER_ACTION_HISTORY_SIZE`.
    user_actions: VecDeque<UserActionRecord>,
    /// Project event subscription and project-release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Copilot started by `start_copilot_for_project`, if any.
    copilot: Option<Entity<Copilot>>,
}
253
254impl ProjectState {
255 fn record_user_action(&mut self, action: UserActionRecord) {
256 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
257 self.user_actions.pop_front();
258 }
259 self.user_actions.push_back(action);
260 }
261
262 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
263 self.events
264 .iter()
265 .cloned()
266 .chain(
267 self.last_event
268 .as_ref()
269 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
270 )
271 .collect()
272 }
273
274 pub fn events_split_by_pause(&self, cx: &App) -> Vec<StoredEvent> {
275 self.events
276 .iter()
277 .cloned()
278 .chain(self.last_event.as_ref().iter().flat_map(|event| {
279 let (one, two) = event.split_by_pause();
280 let one = one.finalize(&self.license_detection_watchers, cx);
281 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
282 one.into_iter().chain(two)
283 }))
284 .collect()
285 }
286
287 fn cancel_pending_prediction(
288 &mut self,
289 pending_prediction: PendingPrediction,
290 cx: &mut Context<EditPredictionStore>,
291 ) {
292 self.cancelled_predictions.insert(pending_prediction.id);
293
294 if pending_prediction.drop_on_cancel {
295 drop(pending_prediction.task);
296 } else {
297 cx.spawn(async move |this, cx| {
298 let Some(prediction_id) = pending_prediction.task.await else {
299 return;
300 };
301
302 this.update(cx, |this, cx| {
303 this.reject_prediction(
304 prediction_id,
305 EditPredictionRejectReason::Canceled,
306 false,
307 cx,
308 );
309 })
310 .ok();
311 })
312 .detach()
313 }
314 }
315
316 fn active_buffer(
317 &self,
318 project: &Entity<Project>,
319 cx: &App,
320 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
321 let project = project.read(cx);
322 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
323 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
324 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
325 Some((active_buffer, registered_buffer.last_position))
326 }
327}
328
/// The prediction currently offered for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Presumably whether the prediction has been displayed — TODO confirm.
    pub was_shown: bool,
    /// Presumably how it was displayed, if it was — TODO confirm.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
336
337impl CurrentEditPrediction {
338 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
339 let Some(new_edits) = self
340 .prediction
341 .interpolate(&self.prediction.buffer.read(cx))
342 else {
343 return false;
344 };
345
346 if self.prediction.buffer != old_prediction.prediction.buffer {
347 return true;
348 }
349
350 let Some(old_edits) = old_prediction
351 .prediction
352 .interpolate(&old_prediction.prediction.buffer.read(cx))
353 else {
354 return true;
355 };
356
357 let requested_by_buffer_id = self.requested_by.buffer_id();
358
359 // This reduces the occurrence of UI thrash from replacing edits
360 //
361 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
362 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
363 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
364 && old_edits.len() == 1
365 && new_edits.len() == 1
366 {
367 let (old_range, old_text) = &old_edits[0];
368 let (new_range, new_text) = &new_edits[0];
369 new_range == old_range && new_text.starts_with(old_text.as_ref())
370 } else {
371 true
372 }
373 }
374}
375
376#[derive(Debug, Clone)]
377enum PredictionRequestedBy {
378 DiagnosticsUpdate,
379 Buffer(EntityId),
380}
381
382impl PredictionRequestedBy {
383 pub fn buffer_id(&self) -> Option<EntityId> {
384 match self {
385 PredictionRequestedBy::DiagnosticsUpdate => None,
386 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
387 }
388 }
389}
390
/// A prediction request that has been started but not yet resolved.
#[derive(Debug)]
struct PendingPrediction {
    /// Identifier recorded in `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if none was produced.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
399
400/// A prediction from the perspective of a buffer.
401#[derive(Debug)]
402enum BufferEditPrediction<'a> {
403 Local { prediction: &'a EditPrediction },
404 Jump { prediction: &'a EditPrediction },
405}
406
407#[cfg(test)]
408impl std::ops::Deref for BufferEditPrediction<'_> {
409 type Target = EditPrediction;
410
411 fn deref(&self) -> &Self::Target {
412 match self {
413 BufferEditPrediction::Local { prediction } => prediction,
414 BufferEditPrediction::Jump { prediction } => prediction,
415 }
416 }
417}
418
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// File as of the last reported change.
    file: Option<Arc<dyn File>>,
    /// Snapshot as of the last reported change; new edits are diffed against it.
    snapshot: TextBufferSnapshot,
    /// Last recorded position in this buffer, if any (returned by `active_buffer`).
    last_position: Option<Anchor>,
    /// Edit subscription and buffer-release observer.
    _subscriptions: [gpui::Subscription; 2],
}
425
/// The edit event still being accumulated for a buffer; `finalize` turns it
/// into a [`StoredEvent`]. Later edits may be coalesced into it (decided in
/// `report_changes_for_buffer`).
#[derive(Clone)]
struct LastEvent {
    /// Buffer state the event's diff starts from.
    old_snapshot: TextBufferSnapshot,
    /// Buffer state the event's diff ends at.
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Anchor range spanning the edits; used for coalescing decisions.
    edit_range: Option<Range<Anchor>>,
    /// Snapshot after the last editing pause, if any; `split_by_pause` splits here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
436
437impl LastEvent {
438 pub fn finalize(
439 &self,
440 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
441 cx: &App,
442 ) -> Option<StoredEvent> {
443 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
444 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
445
446 let in_open_source_repo =
447 [self.new_file.as_ref(), self.old_file.as_ref()]
448 .iter()
449 .all(|file| {
450 file.is_some_and(|file| {
451 license_detection_watchers
452 .get(&file.worktree_id(cx))
453 .is_some_and(|watcher| watcher.is_project_open_source())
454 })
455 });
456
457 let diff = compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
458
459 if path == old_path && diff.is_empty() {
460 None
461 } else {
462 Some(StoredEvent {
463 event: Arc::new(zeta_prompt::Event::BufferChange {
464 old_path,
465 path,
466 diff,
467 in_open_source_repo,
468 // TODO: Actually detect if this edit was predicted or not
469 predicted: false,
470 }),
471 old_snapshot: self.old_snapshot.clone(),
472 })
473 }
474 }
475
476 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
477 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
478 return (self.clone(), None);
479 };
480
481 let before = LastEvent {
482 old_snapshot: self.old_snapshot.clone(),
483 new_snapshot: boundary_snapshot.clone(),
484 old_file: self.old_file.clone(),
485 new_file: self.new_file.clone(),
486 edit_range: None,
487 snapshot_after_last_editing_pause: None,
488 last_edit_time: self.last_edit_time,
489 };
490
491 let after = LastEvent {
492 old_snapshot: boundary_snapshot.clone(),
493 new_snapshot: self.new_snapshot.clone(),
494 old_file: self.old_file.clone(),
495 new_file: self.new_file.clone(),
496 edit_range: None,
497 snapshot_after_last_editing_pause: None,
498 last_edit_time: self.last_edit_time,
499 };
500
501 (before, Some(after))
502 }
503}
504
/// Renders the changes from `old_snapshot` to `new_snapshot` as a unified diff.
///
/// Returns `None` when `new_snapshot` has no edits since `old_snapshot.version`.
/// All edits are merged into one region spanning the first to the last edit,
/// padded with context rows; the region text from each snapshot is diffed with
/// `language::unified_diff_with_offsets`, passing the region start rows as offsets.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<String> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // No edits: nothing to diff.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Region rows are clamped to row 0 and to each snapshot's last row.
    // NOTE(review): leading context is CONTEXT_LINES rows, but the end row is
    // `end + 1 + CONTEXT_LINES` and the end offset below is taken at the start of
    // the row after that, so up to CONTEXT_LINES + 1 trailing rows are included.
    // Confirm the asymmetry is intended (existing outputs may depend on it).
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Region ends at the start of the row after the context end row, or at the
    // end of the buffer when that row does not exist.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some(diff)
}
550
551fn buffer_path_with_id_fallback(
552 file: Option<&Arc<dyn File>>,
553 snapshot: &TextBufferSnapshot,
554 cx: &App,
555) -> Arc<Path> {
556 if let Some(file) = file {
557 file.full_path(cx).into()
558 } else {
559 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
560 }
561}
562
563impl EditPredictionStore {
564 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
565 cx.try_global::<EditPredictionStoreGlobal>()
566 .map(|global| global.0.clone())
567 }
568
569 pub fn global(
570 client: &Arc<Client>,
571 user_store: &Entity<UserStore>,
572 cx: &mut App,
573 ) -> Entity<Self> {
574 cx.try_global::<EditPredictionStoreGlobal>()
575 .map(|global| global.0.clone())
576 .unwrap_or_else(|| {
577 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
578 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
579 ep_store
580 })
581 }
582
583 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
584 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
585 let data_collection_choice = Self::load_data_collection_choice();
586
587 let llm_token = LlmApiToken::default();
588
589 let (reject_tx, reject_rx) = mpsc::unbounded();
590 cx.background_spawn({
591 let client = client.clone();
592 let llm_token = llm_token.clone();
593 let app_version = AppVersion::global(cx);
594 let background_executor = cx.background_executor().clone();
595 async move {
596 Self::handle_rejected_predictions(
597 reject_rx,
598 client,
599 llm_token,
600 app_version,
601 background_executor,
602 )
603 .await
604 }
605 })
606 .detach();
607
608 let this = Self {
609 projects: HashMap::default(),
610 client,
611 user_store,
612 llm_token,
613 _llm_token_subscription: cx.subscribe(
614 &refresh_llm_token_listener,
615 |this, _listener, _event, cx| {
616 let client = this.client.clone();
617 let llm_token = this.llm_token.clone();
618 cx.spawn(async move |_this, _cx| {
619 llm_token.refresh(&client).await?;
620 anyhow::Ok(())
621 })
622 .detach_and_log_err(cx);
623 },
624 ),
625 update_required: false,
626 edit_prediction_model: EditPredictionModel::Zeta2,
627 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
628 sweep_ai: SweepAi::new(cx),
629 mercury: Mercury::new(cx),
630 ollama: Ollama::new(),
631
632 data_collection_choice,
633 reject_predictions_tx: reject_tx,
634 rated_predictions: Default::default(),
635 shown_predictions: Default::default(),
636 };
637
638 this
639 }
640
641 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
642 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
643 let format = ZetaFormat::parse(&version_str).ok()?;
644 let model_id = env::var("ZED_ZETA_MODEL").ok();
645 Some(Zeta2RawConfig { model_id, format })
646 }
647
648 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
649 self.edit_prediction_model = model;
650 }
651
652 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
653 self.zeta2_raw_config = Some(config);
654 }
655
656 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
657 self.zeta2_raw_config.as_ref()
658 }
659
660 pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
661 use ui::IconName;
662 match self.edit_prediction_model {
663 EditPredictionModel::Sweep => {
664 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
665 .with_disabled(IconName::SweepAiDisabled)
666 .with_up(IconName::SweepAiUp)
667 .with_down(IconName::SweepAiDown)
668 .with_error(IconName::SweepAiError)
669 }
670 EditPredictionModel::Mercury => {
671 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
672 }
673 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
674 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
675 .with_disabled(IconName::ZedPredictDisabled)
676 .with_up(IconName::ZedPredictUp)
677 .with_down(IconName::ZedPredictDown)
678 .with_error(IconName::ZedPredictError)
679 }
680 EditPredictionModel::Ollama => {
681 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
682 }
683 }
684 }
685
686 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
687 self.sweep_ai.api_token.read(cx).has_key()
688 }
689
690 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
691 self.mercury.api_token.read(cx).has_key()
692 }
693
694 pub fn clear_history(&mut self) {
695 for project_state in self.projects.values_mut() {
696 project_state.events.clear();
697 project_state.last_event.take();
698 }
699 }
700
701 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
702 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
703 project_state.events.clear();
704 project_state.last_event.take();
705 }
706 }
707
708 pub fn edit_history_for_project(
709 &self,
710 project: &Entity<Project>,
711 cx: &App,
712 ) -> Vec<StoredEvent> {
713 self.projects
714 .get(&project.entity_id())
715 .map(|project_state| project_state.events(cx))
716 .unwrap_or_default()
717 }
718
719 pub fn edit_history_for_project_with_pause_split_last_event(
720 &self,
721 project: &Entity<Project>,
722 cx: &App,
723 ) -> Vec<StoredEvent> {
724 self.projects
725 .get(&project.entity_id())
726 .map(|project_state| project_state.events_split_by_pause(cx))
727 .unwrap_or_default()
728 }
729
730 pub fn context_for_project<'a>(
731 &'a self,
732 project: &Entity<Project>,
733 cx: &'a mut App,
734 ) -> Vec<RelatedFile> {
735 self.projects
736 .get(&project.entity_id())
737 .map(|project| {
738 project
739 .context
740 .update(cx, |context, cx| context.related_files(cx))
741 })
742 .unwrap_or_default()
743 }
744
745 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
746 self.projects
747 .get(&project.entity_id())
748 .and_then(|project| project.copilot.clone())
749 }
750
751 pub fn start_copilot_for_project(
752 &mut self,
753 project: &Entity<Project>,
754 cx: &mut Context<Self>,
755 ) -> Option<Entity<Copilot>> {
756 if DisableAiSettings::get(None, cx).disable_ai {
757 return None;
758 }
759 let state = self.get_or_init_project(project, cx);
760
761 if state.copilot.is_some() {
762 return state.copilot.clone();
763 }
764 let _project = project.clone();
765 let project = project.read(cx);
766
767 let node = project.node_runtime().cloned();
768 if let Some(node) = node {
769 let next_id = project.languages().next_language_server_id();
770 let fs = project.fs().clone();
771
772 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
773 state.copilot = Some(copilot.clone());
774 Some(copilot)
775 } else {
776 None
777 }
778 }
779
780 pub fn context_for_project_with_buffers<'a>(
781 &'a self,
782 project: &Entity<Project>,
783 cx: &'a mut App,
784 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
785 self.projects
786 .get(&project.entity_id())
787 .map(|project| {
788 project
789 .context
790 .update(cx, |context, cx| context.related_files_with_buffers(cx))
791 })
792 .unwrap_or_default()
793 }
794
795 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
796 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
797 self.user_store.read(cx).edit_prediction_usage()
798 } else {
799 None
800 }
801 }
802
803 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
804 self.get_or_init_project(project, cx);
805 }
806
807 pub fn register_buffer(
808 &mut self,
809 buffer: &Entity<Buffer>,
810 project: &Entity<Project>,
811 cx: &mut Context<Self>,
812 ) {
813 let project_state = self.get_or_init_project(project, cx);
814 Self::register_buffer_impl(project_state, buffer, project, cx);
815 }
816
817 fn get_or_init_project(
818 &mut self,
819 project: &Entity<Project>,
820 cx: &mut Context<Self>,
821 ) -> &mut ProjectState {
822 let entity_id = project.entity_id();
823 self.projects
824 .entry(entity_id)
825 .or_insert_with(|| ProjectState {
826 context: {
827 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
828 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
829 this.handle_excerpt_store_event(entity_id, event);
830 })
831 .detach();
832 related_excerpt_store
833 },
834 events: VecDeque::new(),
835 last_event: None,
836 recent_paths: VecDeque::new(),
837 debug_tx: None,
838 registered_buffers: HashMap::default(),
839 current_prediction: None,
840 cancelled_predictions: HashSet::default(),
841 pending_predictions: ArrayVec::new(),
842 next_pending_prediction_id: 0,
843 last_prediction_refresh: None,
844 license_detection_watchers: HashMap::default(),
845 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
846 _subscriptions: [
847 cx.subscribe(&project, Self::handle_project_event),
848 cx.observe_release(&project, move |this, _, cx| {
849 this.projects.remove(&entity_id);
850 cx.notify();
851 }),
852 ],
853 copilot: None,
854 })
855 }
856
857 pub fn remove_project(&mut self, project: &Entity<Project>) {
858 self.projects.remove(&project.entity_id());
859 }
860
861 fn handle_excerpt_store_event(
862 &mut self,
863 project_entity_id: EntityId,
864 event: &RelatedExcerptStoreEvent,
865 ) {
866 if let Some(project_state) = self.projects.get(&project_entity_id) {
867 if let Some(debug_tx) = project_state.debug_tx.clone() {
868 match event {
869 RelatedExcerptStoreEvent::StartedRefresh => {
870 debug_tx
871 .unbounded_send(DebugEvent::ContextRetrievalStarted(
872 ContextRetrievalStartedDebugEvent {
873 project_entity_id: project_entity_id,
874 timestamp: Instant::now(),
875 search_prompt: String::new(),
876 },
877 ))
878 .ok();
879 }
880 RelatedExcerptStoreEvent::FinishedRefresh {
881 cache_hit_count,
882 cache_miss_count,
883 mean_definition_latency,
884 max_definition_latency,
885 } => {
886 debug_tx
887 .unbounded_send(DebugEvent::ContextRetrievalFinished(
888 ContextRetrievalFinishedDebugEvent {
889 project_entity_id: project_entity_id,
890 timestamp: Instant::now(),
891 metadata: vec![
892 (
893 "Cache Hits",
894 format!(
895 "{}/{}",
896 cache_hit_count,
897 cache_hit_count + cache_miss_count
898 )
899 .into(),
900 ),
901 (
902 "Max LSP Time",
903 format!("{} ms", max_definition_latency.as_millis())
904 .into(),
905 ),
906 (
907 "Mean LSP Time",
908 format!("{} ms", mean_definition_latency.as_millis())
909 .into(),
910 ),
911 ],
912 },
913 ))
914 .ok();
915 }
916 }
917 }
918 }
919 }
920
921 pub fn debug_info(
922 &mut self,
923 project: &Entity<Project>,
924 cx: &mut Context<Self>,
925 ) -> mpsc::UnboundedReceiver<DebugEvent> {
926 let project_state = self.get_or_init_project(project, cx);
927 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
928 project_state.debug_tx = Some(debug_watch_tx);
929 debug_watch_rx
930 }
931
932 fn handle_project_event(
933 &mut self,
934 project: Entity<Project>,
935 event: &project::Event,
936 cx: &mut Context<Self>,
937 ) {
938 // TODO [zeta2] init with recent paths
939 match event {
940 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
941 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
942 return;
943 };
944 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
945 if let Some(path) = path {
946 if let Some(ix) = project_state
947 .recent_paths
948 .iter()
949 .position(|probe| probe == &path)
950 {
951 project_state.recent_paths.remove(ix);
952 }
953 project_state.recent_paths.push_front(path);
954 }
955 }
956 project::Event::DiagnosticsUpdated { .. } => {
957 if cx.has_flag::<Zeta2FeatureFlag>() {
958 self.refresh_prediction_from_diagnostics(project, cx);
959 }
960 }
961 _ => (),
962 }
963 }
964
    /// Registers `buffer` in `project_state` if needed and returns its entry.
    ///
    /// If the buffer has a file in one of `project`'s worktrees, also ensures a
    /// [`LicenseDetectionWatcher`] exists for that worktree; the watcher entry is
    /// removed when the worktree entity is released.
    ///
    /// A newly registered buffer records its current snapshot and file, reports
    /// changes on every `BufferEvent::Edited` (while the project is still alive),
    /// and is unregistered when the buffer entity is released. Already
    /// registered buffers are returned unchanged.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Weak project handle: stop reporting once the project is gone.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1030
/// Records a change to `buffer` in the project's edit history.
///
/// Compares the buffer's current text snapshot with the one stored on its registered-buffer
/// entry and returns early if the version is unchanged. Otherwise it:
/// 1. stores the new file/snapshot on the registered buffer,
/// 2. classifies the edits since the previous snapshot into a `UserActionType` and records a
///    `UserActionRecord` positioned at the end of the last edit,
/// 3. merges the change into `last_event` when it directly continues the previous event's
///    snapshot of this buffer and lies within `CHANGE_GROUPING_LINE_SPAN` rows of it, or
/// 4. otherwise finalizes the previous `last_event` into `events` and starts a new one.
fn report_changes_for_buffer(
    &mut self,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_project(project, cx);
    let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

    let buf = buffer.read(cx);
    let new_file = buf.file().cloned();
    let new_snapshot = buf.text_snapshot();
    // Nothing to record if the text has not changed since the last report.
    if new_snapshot.version == registered_buffer.snapshot.version {
        return;
    }

    let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
    let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
    let mut num_edits = 0usize;
    let mut total_deleted = 0usize;
    let mut total_inserted = 0usize;
    // Spans from the start of the first edit to the end of the last edit.
    let mut edit_range: Option<Range<Anchor>> = None;
    // End offset (in `new_snapshot`) of the last edit.
    let mut last_offset: Option<usize> = None;

    for (edit, anchor_range) in
        new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
    {
        num_edits += 1;
        total_deleted += edit.old.len();
        total_inserted += edit.new.len();
        edit_range = Some(match edit_range {
            None => anchor_range,
            Some(acc) => acc.start..anchor_range.end,
        });
        last_offset = Some(edit.new.end);
    }

    if num_edits > 0 {
        // A "char" action means the deleted/inserted length equals the number of edits,
        // i.e. one unit per edit. Lengths are `usize` offset spans (presumably bytes).
        // NOTE(review): a single multi-byte character would then be classified as a
        // selection insert/delete — confirm this is intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Falls back to 0 if the system clock is before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }
    }

    let events = &mut project_state.events;

    let now = cx.background_executor().now();
    if let Some(last_event) = project_state.last_event.as_mut() {
        // This change starts exactly from the snapshot the previous event ended at.
        let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
            == last_event.new_snapshot.remote_id()
            && old_snapshot.version == last_event.new_snapshot.version;

        // ...and its edits overlap, or lie within `CHANGE_GROUPING_LINE_SPAN` rows of,
        // the previous event's edit range.
        let should_coalesce = is_next_snapshot_of_same_buffer
            && edit_range
                .as_ref()
                .zip(last_event.edit_range.as_ref())
                .is_some_and(|(a, b)| {
                    let a = a.to_point(&new_snapshot);
                    let b = b.to_point(&new_snapshot);
                    if a.start > b.end {
                        a.start.row.abs_diff(b.end.row) <= CHANGE_GROUPING_LINE_SPAN
                    } else if b.start > a.end {
                        b.start.row.abs_diff(a.end.row) <= CHANGE_GROUPING_LINE_SPAN
                    } else {
                        true
                    }
                });

        if should_coalesce {
            // After a pause of at least `LAST_CHANGE_GROUPING_TIME`, remember the snapshot
            // as it was when the pause began (before this change is applied to the event).
            let pause_elapsed = last_event
                .last_edit_time
                .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                .unwrap_or(false);
            if pause_elapsed {
                last_event.snapshot_after_last_editing_pause =
                    Some(last_event.new_snapshot.clone());
            }

            // Only the latest change's range is kept for the next proximity check.
            last_event.edit_range = edit_range;
            last_event.new_snapshot = new_snapshot;
            last_event.last_edit_time = Some(now);
            return;
        }
    }

    // Not coalesced: finalize the previous event (if it yields one) into the history.
    // Evicting when `len + 1 >= EVENT_COUNT_MAX` keeps at most `EVENT_COUNT_MAX - 1`
    // finalized events.
    if let Some(event) = project_state.last_event.take() {
        if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
            if events.len() + 1 >= EVENT_COUNT_MAX {
                events.pop_front();
            }
            events.push_back(event);
        }
    }

    project_state.last_event = Some(LastEvent {
        old_file,
        new_file,
        old_snapshot,
        new_snapshot,
        edit_range,
        snapshot_after_last_editing_pause: None,
        last_edit_time: Some(now),
    });
}
1154
1155 fn prediction_at(
1156 &mut self,
1157 buffer: &Entity<Buffer>,
1158 position: Option<language::Anchor>,
1159 project: &Entity<Project>,
1160 cx: &App,
1161 ) -> Option<BufferEditPrediction<'_>> {
1162 let project_state = self.projects.get_mut(&project.entity_id())?;
1163 if let Some(position) = position
1164 && let Some(buffer) = project_state
1165 .registered_buffers
1166 .get_mut(&buffer.entity_id())
1167 {
1168 buffer.last_position = Some(position);
1169 }
1170
1171 let CurrentEditPrediction {
1172 requested_by,
1173 prediction,
1174 ..
1175 } = project_state.current_prediction.as_ref()?;
1176
1177 if prediction.targets_buffer(buffer.read(cx)) {
1178 Some(BufferEditPrediction::Local { prediction })
1179 } else {
1180 let show_jump = match requested_by {
1181 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1182 requested_by_buffer_id == &buffer.entity_id()
1183 }
1184 PredictionRequestedBy::DiagnosticsUpdate => true,
1185 };
1186
1187 if show_jump {
1188 Some(BufferEditPrediction::Jump { prediction })
1189 } else {
1190 None
1191 }
1192 }
1193 }
1194
1195 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1196 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1197 return;
1198 };
1199
1200 let Some(current_prediction) = project_state.current_prediction.take() else {
1201 return;
1202 };
1203
1204 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1205 project_state.cancel_pending_prediction(pending_prediction, cx);
1206 }
1207
1208 match self.edit_prediction_model {
1209 EditPredictionModel::Sweep => {
1210 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1211 }
1212 EditPredictionModel::Mercury => {
1213 mercury::edit_prediction_accepted(
1214 current_prediction.prediction.id,
1215 self.client.http_client(),
1216 cx,
1217 );
1218 }
1219 EditPredictionModel::Ollama => {}
1220 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1221 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1222 }
1223 }
1224 }
1225
1226 async fn handle_rejected_predictions(
1227 rx: UnboundedReceiver<EditPredictionRejection>,
1228 client: Arc<Client>,
1229 llm_token: LlmApiToken,
1230 app_version: Version,
1231 background_executor: BackgroundExecutor,
1232 ) {
1233 let mut rx = std::pin::pin!(rx.peekable());
1234 let mut batched = Vec::new();
1235
1236 while let Some(rejection) = rx.next().await {
1237 batched.push(rejection);
1238
1239 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1240 select_biased! {
1241 next = rx.as_mut().peek().fuse() => {
1242 if next.is_some() {
1243 continue;
1244 }
1245 }
1246 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1247 }
1248 }
1249
1250 let url = client
1251 .http_client()
1252 .build_zed_llm_url("/predict_edits/reject", &[])
1253 .unwrap();
1254
1255 let flush_count = batched
1256 .len()
1257 // in case items have accumulated after failure
1258 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1259 let start = batched.len() - flush_count;
1260
1261 let body = RejectEditPredictionsBodyRef {
1262 rejections: &batched[start..],
1263 };
1264
1265 let result = Self::send_api_request::<()>(
1266 |builder| {
1267 let req = builder
1268 .uri(url.as_ref())
1269 .body(serde_json::to_string(&body)?.into());
1270 anyhow::Ok(req?)
1271 },
1272 client.clone(),
1273 llm_token.clone(),
1274 app_version.clone(),
1275 true,
1276 )
1277 .await;
1278
1279 if result.log_err().is_some() {
1280 batched.drain(start..);
1281 }
1282 }
1283 }
1284
1285 fn reject_current_prediction(
1286 &mut self,
1287 reason: EditPredictionRejectReason,
1288 project: &Entity<Project>,
1289 cx: &App,
1290 ) {
1291 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1292 project_state.pending_predictions.clear();
1293 if let Some(prediction) = project_state.current_prediction.take() {
1294 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1295 }
1296 };
1297 }
1298
1299 fn did_show_current_prediction(
1300 &mut self,
1301 project: &Entity<Project>,
1302 display_type: edit_prediction_types::SuggestionDisplayType,
1303 cx: &mut Context<Self>,
1304 ) {
1305 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1306 return;
1307 };
1308
1309 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1310 return;
1311 };
1312
1313 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1314 let previous_shown_with = current_prediction.shown_with;
1315
1316 if previous_shown_with.is_none() || !is_jump {
1317 current_prediction.shown_with = Some(display_type);
1318 }
1319
1320 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1321
1322 if is_first_non_jump_show {
1323 current_prediction.was_shown = true;
1324 }
1325
1326 let display_type_changed = previous_shown_with != Some(display_type);
1327
1328 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1329 sweep_ai::edit_prediction_shown(
1330 &self.sweep_ai,
1331 self.client.clone(),
1332 ¤t_prediction.prediction,
1333 display_type,
1334 cx,
1335 );
1336 }
1337
1338 if is_first_non_jump_show {
1339 self.shown_predictions
1340 .push_front(current_prediction.prediction.clone());
1341 if self.shown_predictions.len() > 50 {
1342 let completion = self.shown_predictions.pop_back().unwrap();
1343 self.rated_predictions.remove(&completion.id);
1344 }
1345 }
1346 }
1347
1348 fn reject_prediction(
1349 &mut self,
1350 prediction_id: EditPredictionId,
1351 reason: EditPredictionRejectReason,
1352 was_shown: bool,
1353 cx: &App,
1354 ) {
1355 match self.edit_prediction_model {
1356 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1357 self.reject_predictions_tx
1358 .unbounded_send(EditPredictionRejection {
1359 request_id: prediction_id.to_string(),
1360 reason,
1361 was_shown,
1362 })
1363 .log_err();
1364 }
1365 EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1366 EditPredictionModel::Mercury => {
1367 mercury::edit_prediction_rejected(
1368 prediction_id,
1369 was_shown,
1370 reason,
1371 self.client.http_client(),
1372 cx,
1373 );
1374 }
1375 }
1376 }
1377
1378 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1379 self.projects
1380 .get(&project.entity_id())
1381 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1382 }
1383
1384 pub fn refresh_prediction_from_buffer(
1385 &mut self,
1386 project: Entity<Project>,
1387 buffer: Entity<Buffer>,
1388 position: language::Anchor,
1389 cx: &mut Context<Self>,
1390 ) {
1391 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1392 let Some(request_task) = this
1393 .update(cx, |this, cx| {
1394 this.request_prediction(
1395 &project,
1396 &buffer,
1397 position,
1398 PredictEditsRequestTrigger::Other,
1399 cx,
1400 )
1401 })
1402 .log_err()
1403 else {
1404 return Task::ready(anyhow::Ok(None));
1405 };
1406
1407 cx.spawn(async move |_cx| {
1408 request_task.await.map(|prediction_result| {
1409 prediction_result.map(|prediction_result| {
1410 (
1411 prediction_result,
1412 PredictionRequestedBy::Buffer(buffer.entity_id()),
1413 )
1414 })
1415 })
1416 })
1417 })
1418 }
1419
/// Queues a prediction request at a diagnostic near the cursor of the project's active buffer.
///
/// Does nothing if the project is unknown or already has a current prediction (predictions
/// requested from a buffer take precedence). The refresh is throttled per project. If a
/// current prediction appeared by the time the request completes, the result is turned into
/// a `CurrentPreferred` rejection instead of being offered.
pub fn refresh_prediction_from_diagnostics(
    &mut self,
    project: Entity<Project>,
    cx: &mut Context<Self>,
) {
    let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
        return;
    };

    // Prefer predictions from buffer
    if project_state.current_prediction.is_some() {
        return;
    };

    self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
        // Resolve the active buffer and cursor; bail out if predictions are disabled there.
        let Some((active_buffer, snapshot, cursor_point)) = this
            .read_with(cx, |this, cx| {
                let project_state = this.projects.get(&project.entity_id())?;
                let (buffer, position) = project_state.active_buffer(&project, cx)?;
                let snapshot = buffer.read(cx).snapshot();

                if !Self::predictions_enabled_at(&snapshot, position, cx) {
                    return None;
                }

                let cursor_point = position
                    .map(|pos| pos.to_point(&snapshot))
                    .unwrap_or_default();

                Some((buffer, snapshot, cursor_point))
            })
            .log_err()
            .flatten()
        else {
            return Task::ready(anyhow::Ok(None));
        };

        cx.spawn(async move |cx| {
            // A default (presumably empty) search range, so no diagnostic is excluded for
            // being near the cursor — TODO confirm.
            let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                active_buffer,
                &snapshot,
                Default::default(),
                cursor_point,
                &project,
                cx,
            )
            .await?
            else {
                return anyhow::Ok(None);
            };

            let Some(prediction_result) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &jump_buffer,
                        jump_position,
                        PredictEditsRequestTrigger::Diagnostics,
                        cx,
                    )
                })?
                .await?
            else {
                return anyhow::Ok(None);
            };

            // A buffer-requested prediction that arrived meanwhile wins.
            this.update(cx, |this, cx| {
                Some((
                    if this
                        .get_or_init_project(&project, cx)
                        .current_prediction
                        .is_none()
                    {
                        prediction_result
                    } else {
                        EditPredictionResult {
                            id: prediction_result.id,
                            prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                        }
                    },
                    PredictionRequestedBy::DiagnosticsUpdate,
                ))
            })
        })
    });
}
1506
1507 fn predictions_enabled_at(
1508 snapshot: &BufferSnapshot,
1509 position: Option<language::Anchor>,
1510 cx: &App,
1511 ) -> bool {
1512 let file = snapshot.file();
1513 let all_settings = all_language_settings(file, cx);
1514 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1515 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1516 {
1517 return false;
1518 }
1519
1520 if let Some(last_position) = position {
1521 let settings = snapshot.settings_at(last_position, cx);
1522
1523 if !settings.edit_predictions_disabled_in.is_empty()
1524 && let Some(scope) = snapshot.language_scope_at(last_position)
1525 && let Some(scope_name) = scope.override_name()
1526 && settings
1527 .edit_predictions_disabled_in
1528 .iter()
1529 .any(|s| s == scope_name)
1530 {
1531 return false;
1532 }
1533 }
1534
1535 true
1536 }
1537
/// Minimum delay between two prediction refreshes for the same throttle entity; see
/// `queue_prediction_refresh`.
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
/// In tests, throttling is disabled so refreshes run immediately.
#[cfg(test)]
pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1542
1543 fn queue_prediction_refresh(
1544 &mut self,
1545 project: Entity<Project>,
1546 throttle_entity: EntityId,
1547 cx: &mut Context<Self>,
1548 do_refresh: impl FnOnce(
1549 WeakEntity<Self>,
1550 &mut AsyncApp,
1551 )
1552 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1553 + 'static,
1554 ) {
1555 let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1556 let drop_on_cancel = is_ollama;
1557 let max_pending_predictions = if is_ollama { 1 } else { 2 };
1558 let project_state = self.get_or_init_project(&project, cx);
1559 let pending_prediction_id = project_state.next_pending_prediction_id;
1560 project_state.next_pending_prediction_id += 1;
1561 let last_request = project_state.last_prediction_refresh;
1562
1563 let task = cx.spawn(async move |this, cx| {
1564 if let Some((last_entity, last_timestamp)) = last_request
1565 && throttle_entity == last_entity
1566 && let Some(timeout) =
1567 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1568 {
1569 cx.background_executor().timer(timeout).await;
1570 }
1571
1572 // If this task was cancelled before the throttle timeout expired,
1573 // do not perform a request.
1574 let mut is_cancelled = true;
1575 this.update(cx, |this, cx| {
1576 let project_state = this.get_or_init_project(&project, cx);
1577 if !project_state
1578 .cancelled_predictions
1579 .remove(&pending_prediction_id)
1580 {
1581 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1582 is_cancelled = false;
1583 }
1584 })
1585 .ok();
1586 if is_cancelled {
1587 return None;
1588 }
1589
1590 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1591 let new_prediction_id = new_prediction_result
1592 .as_ref()
1593 .map(|(prediction, _)| prediction.id.clone());
1594
1595 // When a prediction completes, remove it from the pending list, and cancel
1596 // any pending predictions that were enqueued before it.
1597 this.update(cx, |this, cx| {
1598 let project_state = this.get_or_init_project(&project, cx);
1599
1600 let is_cancelled = project_state
1601 .cancelled_predictions
1602 .remove(&pending_prediction_id);
1603
1604 let new_current_prediction = if !is_cancelled
1605 && let Some((prediction_result, requested_by)) = new_prediction_result
1606 {
1607 match prediction_result.prediction {
1608 Ok(prediction) => {
1609 let new_prediction = CurrentEditPrediction {
1610 requested_by,
1611 prediction,
1612 was_shown: false,
1613 shown_with: None,
1614 };
1615
1616 if let Some(current_prediction) =
1617 project_state.current_prediction.as_ref()
1618 {
1619 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1620 {
1621 this.reject_current_prediction(
1622 EditPredictionRejectReason::Replaced,
1623 &project,
1624 cx,
1625 );
1626
1627 Some(new_prediction)
1628 } else {
1629 this.reject_prediction(
1630 new_prediction.prediction.id,
1631 EditPredictionRejectReason::CurrentPreferred,
1632 false,
1633 cx,
1634 );
1635 None
1636 }
1637 } else {
1638 Some(new_prediction)
1639 }
1640 }
1641 Err(reject_reason) => {
1642 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1643 None
1644 }
1645 }
1646 } else {
1647 None
1648 };
1649
1650 let project_state = this.get_or_init_project(&project, cx);
1651
1652 if let Some(new_prediction) = new_current_prediction {
1653 project_state.current_prediction = Some(new_prediction);
1654 }
1655
1656 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1657 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1658 if pending_prediction.id == pending_prediction_id {
1659 pending_predictions.remove(ix);
1660 for pending_prediction in pending_predictions.drain(0..ix) {
1661 project_state.cancel_pending_prediction(pending_prediction, cx)
1662 }
1663 break;
1664 }
1665 }
1666 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1667 cx.notify();
1668 })
1669 .ok();
1670
1671 new_prediction_id
1672 });
1673
1674 if project_state.pending_predictions.len() < max_pending_predictions {
1675 project_state.pending_predictions.push(PendingPrediction {
1676 id: pending_prediction_id,
1677 task,
1678 drop_on_cancel,
1679 });
1680 } else {
1681 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1682 project_state.pending_predictions.push(PendingPrediction {
1683 id: pending_prediction_id,
1684 task,
1685 drop_on_cancel,
1686 });
1687 project_state.cancel_pending_prediction(pending_prediction, cx);
1688 }
1689 }
1690
1691 pub fn request_prediction(
1692 &mut self,
1693 project: &Entity<Project>,
1694 active_buffer: &Entity<Buffer>,
1695 position: language::Anchor,
1696 trigger: PredictEditsRequestTrigger,
1697 cx: &mut Context<Self>,
1698 ) -> Task<Result<Option<EditPredictionResult>>> {
1699 self.request_prediction_internal(
1700 project.clone(),
1701 active_buffer.clone(),
1702 position,
1703 trigger,
1704 cx.has_flag::<Zeta2FeatureFlag>(),
1705 cx,
1706 )
1707 }
1708
1709 fn request_prediction_internal(
1710 &mut self,
1711 project: Entity<Project>,
1712 active_buffer: Entity<Buffer>,
1713 position: language::Anchor,
1714 trigger: PredictEditsRequestTrigger,
1715 allow_jump: bool,
1716 cx: &mut Context<Self>,
1717 ) -> Task<Result<Option<EditPredictionResult>>> {
1718 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1719
1720 self.get_or_init_project(&project, cx);
1721 let project_state = self.projects.get(&project.entity_id()).unwrap();
1722 let stored_events = project_state.events_split_by_pause(cx);
1723 let has_events = !stored_events.is_empty();
1724 let events: Vec<Arc<zeta_prompt::Event>> =
1725 stored_events.into_iter().map(|e| e.event).collect();
1726 let debug_tx = project_state.debug_tx.clone();
1727
1728 let snapshot = active_buffer.read(cx).snapshot();
1729 let cursor_point = position.to_point(&snapshot);
1730 let current_offset = position.to_offset(&snapshot);
1731
1732 let mut user_actions: Vec<UserActionRecord> =
1733 project_state.user_actions.iter().cloned().collect();
1734
1735 if let Some(last_action) = user_actions.last() {
1736 if last_action.buffer_id == active_buffer.entity_id()
1737 && current_offset != last_action.offset
1738 {
1739 let timestamp_epoch_ms = SystemTime::now()
1740 .duration_since(UNIX_EPOCH)
1741 .map(|d| d.as_millis() as u64)
1742 .unwrap_or(0);
1743 user_actions.push(UserActionRecord {
1744 action_type: UserActionType::CursorMovement,
1745 buffer_id: active_buffer.entity_id(),
1746 line_number: cursor_point.row,
1747 offset: current_offset,
1748 timestamp_epoch_ms,
1749 });
1750 }
1751 }
1752 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1753 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1754 let diagnostic_search_range =
1755 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1756
1757 let related_files = self.context_for_project(&project, cx);
1758
1759 let inputs = EditPredictionModelInput {
1760 project: project.clone(),
1761 buffer: active_buffer.clone(),
1762 snapshot: snapshot.clone(),
1763 position,
1764 events,
1765 related_files,
1766 recent_paths: project_state.recent_paths.clone(),
1767 trigger,
1768 diagnostic_search_range: diagnostic_search_range.clone(),
1769 debug_tx,
1770 user_actions,
1771 };
1772
1773 let task = match &self.edit_prediction_model {
1774 EditPredictionModel::Zeta1 => {
1775 if should_send_testing_zeta2_request() {
1776 let mut zeta2_inputs = inputs.clone();
1777 zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
1778 zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach();
1779 }
1780 zeta1::request_prediction_with_zeta1(self, inputs, cx)
1781 }
1782 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
1783 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1784 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1785 EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
1786 };
1787
1788 cx.spawn(async move |this, cx| {
1789 let prediction = task.await?;
1790
1791 if prediction.is_none() && allow_jump {
1792 let cursor_point = position.to_point(&snapshot);
1793 if has_events
1794 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1795 active_buffer.clone(),
1796 &snapshot,
1797 diagnostic_search_range,
1798 cursor_point,
1799 &project,
1800 cx,
1801 )
1802 .await?
1803 {
1804 return this
1805 .update(cx, |this, cx| {
1806 this.request_prediction_internal(
1807 project,
1808 jump_buffer,
1809 jump_position,
1810 trigger,
1811 false,
1812 cx,
1813 )
1814 })?
1815 .await;
1816 }
1817
1818 return anyhow::Ok(None);
1819 }
1820
1821 Ok(prediction)
1822 })
1823 }
1824
1825 async fn next_diagnostic_location(
1826 active_buffer: Entity<Buffer>,
1827 active_buffer_snapshot: &BufferSnapshot,
1828 active_buffer_diagnostic_search_range: Range<Point>,
1829 active_buffer_cursor_point: Point,
1830 project: &Entity<Project>,
1831 cx: &mut AsyncApp,
1832 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1833 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1834 let mut jump_location = active_buffer_snapshot
1835 .diagnostic_groups(None)
1836 .into_iter()
1837 .filter_map(|(_, group)| {
1838 let range = &group.entries[group.primary_ix]
1839 .range
1840 .to_point(&active_buffer_snapshot);
1841 if range.overlaps(&active_buffer_diagnostic_search_range) {
1842 None
1843 } else {
1844 Some(range.start)
1845 }
1846 })
1847 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1848 .map(|position| {
1849 (
1850 active_buffer.clone(),
1851 active_buffer_snapshot.anchor_before(position),
1852 )
1853 });
1854
1855 if jump_location.is_none() {
1856 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1857 let file = buffer.file()?;
1858
1859 Some(ProjectPath {
1860 worktree_id: file.worktree_id(cx),
1861 path: file.path().clone(),
1862 })
1863 });
1864
1865 let buffer_task = project.update(cx, |project, cx| {
1866 let (path, _, _) = project
1867 .diagnostic_summaries(false, cx)
1868 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1869 .max_by_key(|(path, _, _)| {
1870 // find the buffer with errors that shares most parent directories
1871 path.path
1872 .components()
1873 .zip(
1874 active_buffer_path
1875 .as_ref()
1876 .map(|p| p.path.components())
1877 .unwrap_or_default(),
1878 )
1879 .take_while(|(a, b)| a == b)
1880 .count()
1881 })?;
1882
1883 Some(project.open_buffer(path, cx))
1884 });
1885
1886 if let Some(buffer_task) = buffer_task {
1887 let closest_buffer = buffer_task.await?;
1888
1889 jump_location = closest_buffer
1890 .read_with(cx, |buffer, _cx| {
1891 buffer
1892 .buffer_diagnostics(None)
1893 .into_iter()
1894 .min_by_key(|entry| entry.diagnostic.severity)
1895 .map(|entry| entry.range.start)
1896 })
1897 .map(|position| (closest_buffer, position));
1898 }
1899 }
1900
1901 anyhow::Ok(jump_location)
1902 }
1903
1904 async fn send_raw_llm_request(
1905 request: RawCompletionRequest,
1906 client: Arc<Client>,
1907 custom_url: Option<Arc<Url>>,
1908 llm_token: LlmApiToken,
1909 app_version: Version,
1910 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1911 let url = if let Some(custom_url) = custom_url {
1912 custom_url.as_ref().clone()
1913 } else {
1914 client
1915 .http_client()
1916 .build_zed_llm_url("/predict_edits/raw", &[])?
1917 };
1918
1919 Self::send_api_request(
1920 |builder| {
1921 let req = builder
1922 .uri(url.as_ref())
1923 .body(serde_json::to_string(&request)?.into());
1924 Ok(req?)
1925 },
1926 client,
1927 llm_token,
1928 app_version,
1929 true,
1930 )
1931 .await
1932 }
1933
1934 pub(crate) async fn send_v3_request(
1935 input: ZetaPromptInput,
1936 client: Arc<Client>,
1937 llm_token: LlmApiToken,
1938 app_version: Version,
1939 trigger: PredictEditsRequestTrigger,
1940 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
1941 let url = client
1942 .http_client()
1943 .build_zed_llm_url("/predict_edits/v3", &[])?;
1944
1945 let request = PredictEditsV3Request { input, trigger };
1946
1947 let json_bytes = serde_json::to_vec(&request)?;
1948 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
1949
1950 Self::send_api_request(
1951 |builder| {
1952 let req = builder
1953 .uri(url.as_ref())
1954 .header("Content-Encoding", "zstd")
1955 .body(compressed.clone().into());
1956 Ok(req?)
1957 },
1958 client,
1959 llm_token,
1960 app_version,
1961 true,
1962 )
1963 .await
1964 }
1965
1966 fn handle_api_response<T>(
1967 this: &WeakEntity<Self>,
1968 response: Result<(T, Option<EditPredictionUsage>)>,
1969 cx: &mut gpui::AsyncApp,
1970 ) -> Result<T> {
1971 match response {
1972 Ok((data, usage)) => {
1973 if let Some(usage) = usage {
1974 this.update(cx, |this, cx| {
1975 this.user_store.update(cx, |user_store, cx| {
1976 user_store.update_edit_prediction_usage(usage, cx);
1977 });
1978 })
1979 .ok();
1980 }
1981 Ok(data)
1982 }
1983 Err(err) => {
1984 if err.is::<ZedUpdateRequiredError>() {
1985 cx.update(|cx| {
1986 this.update(cx, |this, _cx| {
1987 this.update_required = true;
1988 })
1989 .ok();
1990
1991 let error_message: SharedString = err.to_string().into();
1992 show_app_notification(
1993 NotificationId::unique::<ZedUpdateRequiredError>(),
1994 cx,
1995 move |cx| {
1996 cx.new(|cx| {
1997 ErrorMessagePrompt::new(error_message.clone(), cx)
1998 .with_link_button("Update Zed", "https://zed.dev/releases")
1999 })
2000 },
2001 );
2002 });
2003 }
2004 Err(err)
2005 }
2006 }
2007 }
2008
/// Sends a POST request built by `build` and decodes a JSON `Res` from a successful response.
///
/// Each attempt sets `Content-Type: application/json` and the Zed version header. A bearer
/// `Authorization` header is added when a token is available. The token comes from the
/// `ZED_PREDICT_EDITS_TOKEN` environment variable if set, otherwise from `llm_token`. A failure
/// to acquire it is an error when `require_auth` is true, and means "no token" otherwise.
///
/// Returns the decoded body together with the `EditPredictionUsage` parsed from the response
/// headers, if present.
///
/// # Errors
/// - `ZedUpdateRequiredError` when the response advertises a minimum required version newer
///   than `app_version` (checked before the status code).
/// - An error carrying the status and body when the request fails, after at most one retry
///   with a refreshed LLM token if the response says the token needs refreshing.
/// - Any error from building, sending, reading, or decoding.
async fn send_api_request<Res>(
    build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
    client: Arc<Client>,
    llm_token: LlmApiToken,
    app_version: Version,
    require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
    Res: DeserializeOwned,
{
    let http_client = client.http_client();

    let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
        Some(custom_token)
    } else if require_auth {
        Some(llm_token.acquire(&client).await?)
    } else {
        llm_token.acquire(&client).await.ok()
    };
    let mut did_retry = false;

    loop {
        let request_builder = http_client::Request::builder().method(Method::POST);

        let mut request_builder = request_builder
            .header("Content-Type", "application/json")
            .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

        // Only add Authorization header if we have a token
        if let Some(ref token_value) = token {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", token_value));
        }

        let request = build(request_builder)?;

        let mut response = http_client.send(request).await?;

        // The server can demand a newer client regardless of the response status.
        if let Some(minimum_required_version) = response
            .headers()
            .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
            .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
        {
            anyhow::ensure!(
                app_version >= minimum_required_version,
                ZedUpdateRequiredError {
                    minimum_version: minimum_required_version
                }
            );
        }

        if response.status().is_success() {
            let usage = EditPredictionUsage::from_headers(response.headers()).ok();

            let mut body = Vec::new();
            response.body_mut().read_to_end(&mut body).await?;
            return Ok((serde_json::from_slice(&body)?, usage));
        } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
            // Retry once with a refreshed LLM token.
            // NOTE(review): this also replaces a token supplied via `ZED_PREDICT_EDITS_TOKEN`.
            did_retry = true;
            token = Some(llm_token.refresh(&client).await?);
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "Request failed with status: {:?}\nBody: {}",
                response.status(),
                body
            );
        }
    }
}
2080
2081 pub fn refresh_context(
2082 &mut self,
2083 project: &Entity<Project>,
2084 buffer: &Entity<language::Buffer>,
2085 cursor_position: language::Anchor,
2086 cx: &mut Context<Self>,
2087 ) {
2088 self.get_or_init_project(project, cx)
2089 .context
2090 .update(cx, |store, cx| {
2091 store.refresh(buffer.clone(), cursor_position, cx);
2092 });
2093 }
2094
/// Overrides the project's related files in its context store (CLI support only).
#[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
    &mut self,
    project: &Entity<Project>,
    related_files: Vec<RelatedFile>,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_project(project, cx);
    project_state.context.update(cx, |context_store, cx| {
        context_store.set_related_files(related_files, cx);
    });
}
2108
/// Replaces the project's recent paths with `paths` (CLI support only).
#[cfg(feature = "cli-support")]
pub fn set_recent_paths_for_project(
    &mut self,
    project: &Entity<Project>,
    paths: impl IntoIterator<Item = project::ProjectPath>,
    cx: &mut Context<Self>,
) {
    let recent_paths = &mut self.get_or_init_project(project, cx).recent_paths;
    *recent_paths = paths.into_iter().collect();
}
2119
2120 fn is_file_open_source(
2121 &self,
2122 project: &Entity<Project>,
2123 file: &Arc<dyn File>,
2124 cx: &App,
2125 ) -> bool {
2126 if !file.is_local() || file.is_private() {
2127 return false;
2128 }
2129 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2130 return false;
2131 };
2132 project_state
2133 .license_detection_watchers
2134 .get(&file.worktree_id(cx))
2135 .as_ref()
2136 .is_some_and(|watcher| watcher.is_project_open_source())
2137 }
2138
2139 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2140 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2141 }
2142
2143 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2144 if !self.data_collection_choice.is_enabled(cx) {
2145 return false;
2146 }
2147 events.iter().all(|event| {
2148 matches!(
2149 event.as_ref(),
2150 zeta_prompt::Event::BufferChange {
2151 in_open_source_repo: true,
2152 ..
2153 }
2154 )
2155 })
2156 }
2157
2158 fn load_data_collection_choice() -> DataCollectionChoice {
2159 let choice = KEY_VALUE_STORE
2160 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2161 .log_err()
2162 .flatten();
2163
2164 match choice.as_deref() {
2165 Some("true") => DataCollectionChoice::Enabled,
2166 Some("false") => DataCollectionChoice::Disabled,
2167 Some(_) => {
2168 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2169 DataCollectionChoice::NotAnswered
2170 }
2171 None => DataCollectionChoice::NotAnswered,
2172 }
2173 }
2174
2175 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2176 self.data_collection_choice = self.data_collection_choice.toggle();
2177 let new_choice = self.data_collection_choice;
2178 let is_enabled = new_choice.is_enabled(cx);
2179 db::write_and_log(cx, move || {
2180 KEY_VALUE_STORE.write_kvp(
2181 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2182 is_enabled.to_string(),
2183 )
2184 });
2185 }
2186
    /// Returns a double-ended iterator over the predictions stored in `shown_predictions`.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2190
    /// Returns how many entries `shown_predictions` currently holds.
    ///
    /// Despite the "completions" in its name, this counts the same collection that
    /// `shown_predictions()` iterates.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2194
    /// Returns whether `id` is in the set of rated predictions, which `rate_prediction`
    /// populates.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2198
2199 pub fn rate_prediction(
2200 &mut self,
2201 prediction: &EditPrediction,
2202 rating: EditPredictionRating,
2203 feedback: String,
2204 cx: &mut Context<Self>,
2205 ) {
2206 self.rated_predictions.insert(prediction.id.clone());
2207 telemetry::event!(
2208 "Edit Prediction Rated",
2209 request_id = prediction.id.to_string(),
2210 rating,
2211 inputs = prediction.inputs,
2212 output = prediction
2213 .edit_preview
2214 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2215 feedback
2216 );
2217 self.client.telemetry().flush_events().detach();
2218 cx.notify();
2219 }
2220}
2221
2222pub(crate) fn filter_redundant_excerpts(
2223 mut related_files: Vec<RelatedFile>,
2224 cursor_path: &Path,
2225 cursor_row_range: Range<u32>,
2226) -> Vec<RelatedFile> {
2227 for file in &mut related_files {
2228 if file.path.as_ref() == cursor_path {
2229 file.excerpts.retain(|excerpt| {
2230 excerpt.row_range.start < cursor_row_range.start
2231 || excerpt.row_range.end > cursor_row_range.end
2232 });
2233 }
2234 }
2235 related_files.retain(|file| !file.excerpts.is_empty());
2236 related_files
2237}
2238
/// Error signalling that this Zed build is too old to keep using edit predictions.
///
/// NOTE(review): presumably constructed from the server's minimum-required-version
/// response; confirm at the construction site.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version allowed to keep using edit predictions; shown in the error message.
    minimum_version: Version,
}
2246
/// The user's answer to whether edit-prediction data may be collected.
///
/// Persisted under `ZED_PREDICT_DATA_COLLECTION_CHOICE` as `"true"` or `"false"`. When read
/// back, an absent or unrecognized stored value maps to `NotAnswered`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not made a choice yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}
2253
2254impl DataCollectionChoice {
2255 pub fn is_enabled(self, cx: &App) -> bool {
2256 if cx.is_staff() {
2257 return true;
2258 }
2259 match self {
2260 Self::Enabled => true,
2261 Self::NotAnswered | Self::Disabled => false,
2262 }
2263 }
2264
2265 #[must_use]
2266 pub fn toggle(&self) -> DataCollectionChoice {
2267 match self {
2268 Self::Enabled => Self::Disabled,
2269 Self::Disabled => Self::Enabled,
2270 Self::NotAnswered => Self::Enabled,
2271 }
2272 }
2273}
2274
2275impl From<bool> for DataCollectionChoice {
2276 fn from(value: bool) -> Self {
2277 match value {
2278 true => DataCollectionChoice::Enabled,
2279 false => DataCollectionChoice::Disabled,
2280 }
2281 }
2282}
2283
/// Marker type for the edit-prediction upsell; its dismissal state is tracked via
/// [`Dismissable`].
struct ZedPredictUpsell;
2285
2286impl Dismissable for ZedPredictUpsell {
2287 const KEY: &'static str = "dismissed-edit-predict-upsell";
2288
2289 fn dismissed() -> bool {
2290 // To make this backwards compatible with older versions of Zed, we
2291 // check if the user has seen the previous Edit Prediction Onboarding
2292 // before, by checking the data collection choice which was written to
2293 // the database once the user clicked on "Accept and Enable"
2294 if KEY_VALUE_STORE
2295 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2296 .log_err()
2297 .is_some_and(|s| s.is_some())
2298 {
2299 return true;
2300 }
2301
2302 KEY_VALUE_STORE
2303 .read_kvp(Self::KEY)
2304 .log_err()
2305 .is_some_and(|s| s.is_some())
2306 }
2307}
2308
/// Returns `true` while the edit-prediction upsell has not been dismissed, as reported by
/// `ZedPredictUpsell::dismissed`.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2312
2313pub fn init(cx: &mut App) {
2314 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2315 workspace.register_action(
2316 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2317 ZedPredictModal::toggle(
2318 workspace,
2319 workspace.user_store().clone(),
2320 workspace.client().clone(),
2321 window,
2322 cx,
2323 )
2324 },
2325 );
2326
2327 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2328 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2329 settings
2330 .project
2331 .all_languages
2332 .edit_predictions
2333 .get_or_insert_default()
2334 .provider = Some(EditPredictionProvider::None)
2335 });
2336 });
2337 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2338 EditPredictionStore::try_global(cx).and_then(|store| {
2339 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2340 })
2341 }
2342
2343 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2344 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2345 copilot_ui::initiate_sign_in(copilot, window, cx);
2346 }
2347 });
2348 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2349 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2350 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2351 }
2352 });
2353 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2354 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2355 copilot_ui::initiate_sign_out(copilot, window, cx);
2356 }
2357 });
2358 })
2359 .detach();
2360}