1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use std::env;
40use text::Edit;
41use workspace::Workspace;
42use zeta_prompt::{ZetaFormat, ZetaPromptInput};
43
44use std::mem;
45use std::ops::Range;
46use std::path::Path;
47use std::rc::Rc;
48use std::str::FromStr as _;
49use std::sync::Arc;
50use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59pub mod ollama;
60mod onboarding_modal;
61pub mod open_ai_response;
62mod prediction;
63pub mod sweep_ai;
64
65pub mod udiff;
66
67mod capture_example;
68mod zed_edit_prediction_delegate;
69pub mod zeta1;
70pub mod zeta2;
71
72#[cfg(test)]
73mod edit_prediction_tests;
74
75use crate::license_detection::LicenseDetectionWatcher;
76use crate::mercury::Mercury;
77use crate::ollama::Ollama;
78use crate::onboarding_modal::ZedPredictModal;
79pub use crate::prediction::EditPrediction;
80pub use crate::prediction::EditPredictionId;
81use crate::prediction::EditPredictionResult;
82pub use crate::sweep_ai::SweepAi;
83pub use capture_example::capture_example;
84pub use language_model::ApiKeyState;
85pub use telemetry_events::EditPredictionRating;
86pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
87
// Actions declared in the `edit_prediction` namespace via gpui's `actions!` macro.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
97
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance used when deciding whether neighbouring edit events belong together.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window for grouping consecutive changes (usage not visible here — presumably
/// compared against the last edit time; confirm).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_millis(1_000);
/// Key under which the data-collection choice is stored (presumably in the key/value store).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reporting rejected predictions (presumably batches rejections).
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_millis(15_000);
104
/// Feature flag named `"zeta2"`; among other things it gates refreshing predictions when a
/// project's diagnostics are updated.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Presumably enables the flag for staff accounts regardless of assignment (per the
    // `FeatureFlag` trait contract — confirm).
    fn enabled_for_staff() -> bool {
        true
    }
}
114
/// Newtype wrapper that lets the singleton [`EditPredictionStore`] entity be stored as a gpui
/// global (see `EditPredictionStore::global` / `try_global`).
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
119
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// `EditPredictionStore::new` fills this from the `ZED_ZETA_FORMAT` (required) and
/// `ZED_ZETA_MODEL` (optional) environment variables. It can also be set explicitly with
/// `EditPredictionStore::set_zeta2_raw_config`.
///
/// NOTE(review): the "version" used as the Baseten environment name (lowercased) presumably
/// refers to `format`. This struct has no separate version field — confirm against the
/// request code.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (read from `ZED_ZETA_MODEL` when built from the environment).
    pub model_id: Option<String>,
    /// Prompt format used when the client constructs the raw prompt.
    pub format: ZetaFormat,
}
128
/// Central state for edit predictions: per-project history and context, the selected
/// prediction backend, and bookkeeping for shown, rated and rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // LLM API token, shared with the background rejection-reporting task and refreshed
    // whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    // Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // Starts out false in `new`; presumably set when the server requires a newer client.
    update_required: bool,
    // Backend used to produce predictions (`Zeta2` by default in `new`).
    edit_prediction_model: EditPredictionModel,
    // Raw Zeta2 endpoint configuration, if any (see `Zeta2RawConfig`).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    // Sender half of the channel drained by the background `handle_rejected_predictions` task.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // Presumably recently displayed predictions, kept for rating — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    // Presumably ids of predictions the user has already rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
146
/// The backend used to generate edit predictions.
///
/// Note: `Default` yields `Zeta1`, while `EditPredictionStore::new` starts with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
    Ollama,
}
156
/// Everything a prediction backend receives to produce a prediction for one request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    // Buffer the prediction is requested for, and its snapshot at request time.
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Cursor position within `snapshot`.
    position: Anchor,
    // Recent edit history for the project.
    events: Vec<Arc<zeta_prompt::Event>>,
    // Related context files gathered for the project.
    related_files: Vec<RelatedFile>,
    // Recently active paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    // Presumably the region whose diagnostics are considered — confirm.
    diagnostic_search_range: Range<Point>,
    // Optional sink for debug events about this request.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // Recent user actions recorded for the project.
    pub user_actions: Vec<UserActionRecord>,
}
171
/// Debug notifications streamed to the receiver returned by `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
179
/// Emitted when context retrieval for a project starts.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Left empty when emitted from excerpt-store refreshes.
    pub search_prompt: String,
}
186
/// Emitted when context retrieval for a project finishes.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Label/value pairs for display, e.g. cache hits and LSP latencies.
    pub metadata: Vec<(&'static str, SharedString)>,
}
193
/// Emitted when a prediction request starts for a buffer position.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // Prompt sent to the model, when the backend exposes one.
    pub prompt: Option<String>,
}
200
/// Emitted when a prediction request finishes for a buffer position.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // Raw model output, when available.
    pub model_output: Option<String>,
}
207
/// Maximum number of user actions kept per project; the oldest is evicted first.
const USER_ACTION_HISTORY_SIZE: usize = 16;
209
/// A single recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    // Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // Presumably milliseconds since the UNIX epoch (per the name) — confirm at the call site.
    pub timestamp_epoch_ms: u64,
}
218
/// Kind of user action recorded in [`UserActionRecord`] (meanings inferred from the variant
/// names — confirm at the recording site).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
227
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded change (for buffer changes: paths, diff and whether it was predicted).
    pub event: Arc<zeta_prompt::Event>,
    /// Snapshot of the buffer before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Anchored range covering the edited region, created in the post-change snapshot.
    pub edit_range: Range<Anchor>,
}
235
236impl StoredEvent {
237 fn can_merge(
238 &self,
239 next_old_event: &&&StoredEvent,
240 new_snapshot: &TextBufferSnapshot,
241 last_edit_range: &Range<Anchor>,
242 ) -> bool {
243 // Events must be for the same buffer
244 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
245 return false;
246 }
247 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
248 return false;
249 }
250
251 let a_is_predicted = matches!(
252 self.event.as_ref(),
253 zeta_prompt::Event::BufferChange {
254 predicted: true,
255 ..
256 }
257 );
258 let b_is_predicted = matches!(
259 next_old_event.event.as_ref(),
260 zeta_prompt::Event::BufferChange {
261 predicted: true,
262 ..
263 }
264 );
265
266 // If events come from the same source (both predicted or both manual) then
267 // we would have coalesced them already.
268 if a_is_predicted == b_is_predicted {
269 return false;
270 }
271
272 let left_range = self.edit_range.to_point(new_snapshot);
273 let right_range = next_old_event.edit_range.to_point(new_snapshot);
274 let latest_range = last_edit_range.to_point(&new_snapshot);
275
276 // Events near to the latest edit are not merged if their sources differ.
277 if lines_between_ranges(&left_range, &latest_range)
278 .min(lines_between_ranges(&right_range, &latest_range))
279 <= CHANGE_GROUPING_LINE_SPAN
280 {
281 return false;
282 }
283
284 // Events that are distant from each other are not merged.
285 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
286 return false;
287 }
288
289 true
290 }
291}
292
293fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
294 if left.start > right.end {
295 return left.start.row - right.end.row;
296 }
297 if right.start > left.end {
298 return right.start.row - left.end.row;
299 }
300 0
301}
302
/// Per-project edit prediction state owned by [`EditPredictionStore`].
struct ProjectState {
    // Finalized edit history.
    events: VecDeque<StoredEvent>,
    // In-progress event that has not been finalized into `events` yet.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    // Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Presumably the id handed to the next `PendingPrediction` — confirm.
    next_pending_prediction_id: usize,
    // At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Sink installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // Presumably the buffer and time of the last refresh, used for throttling — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that have been cancelled.
    cancelled_predictions: HashSet<usize>,
    // Store of related excerpts used as prediction context.
    context: Entity<RelatedExcerptStore>,
    // One license watcher per worktree, removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded history of recent user actions (see `USER_ACTION_HISTORY_SIZE`).
    user_actions: VecDeque<UserActionRecord>,
    // Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
320
321impl ProjectState {
322 fn record_user_action(&mut self, action: UserActionRecord) {
323 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
324 self.user_actions.pop_front();
325 }
326 self.user_actions.push_back(action);
327 }
328
329 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
330 self.events
331 .iter()
332 .cloned()
333 .chain(self.last_event.as_ref().iter().flat_map(|event| {
334 let (one, two) = event.split_by_pause();
335 let one = one.finalize(&self.license_detection_watchers, cx);
336 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
337 one.into_iter().chain(two)
338 }))
339 .collect()
340 }
341
342 fn cancel_pending_prediction(
343 &mut self,
344 pending_prediction: PendingPrediction,
345 cx: &mut Context<EditPredictionStore>,
346 ) {
347 self.cancelled_predictions.insert(pending_prediction.id);
348
349 if pending_prediction.drop_on_cancel {
350 drop(pending_prediction.task);
351 } else {
352 cx.spawn(async move |this, cx| {
353 let Some(prediction_id) = pending_prediction.task.await else {
354 return;
355 };
356
357 this.update(cx, |this, cx| {
358 this.reject_prediction(
359 prediction_id,
360 EditPredictionRejectReason::Canceled,
361 false,
362 cx,
363 );
364 })
365 .ok();
366 })
367 .detach()
368 }
369 }
370
371 fn active_buffer(
372 &self,
373 project: &Entity<Project>,
374 cx: &App,
375 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
376 let project = project.read(cx);
377 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
378 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
379 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
380 Some((active_buffer, registered_buffer.last_position))
381 }
382}
383
/// The prediction currently offered for a project, along with how it was requested.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Presumably whether the prediction has been displayed to the user — confirm.
    pub was_shown: bool,
    // Presumably how it was displayed, once shown — confirm.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
391
392impl CurrentEditPrediction {
393 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
394 let Some(new_edits) = self
395 .prediction
396 .interpolate(&self.prediction.buffer.read(cx))
397 else {
398 return false;
399 };
400
401 if self.prediction.buffer != old_prediction.prediction.buffer {
402 return true;
403 }
404
405 let Some(old_edits) = old_prediction
406 .prediction
407 .interpolate(&old_prediction.prediction.buffer.read(cx))
408 else {
409 return true;
410 };
411
412 let requested_by_buffer_id = self.requested_by.buffer_id();
413
414 // This reduces the occurrence of UI thrash from replacing edits
415 //
416 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
417 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
418 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
419 && old_edits.len() == 1
420 && new_edits.len() == 1
421 {
422 let (old_range, old_text) = &old_edits[0];
423 let (new_range, new_text) = &new_edits[0];
424 new_range == old_range && new_text.starts_with(old_text.as_ref())
425 } else {
426 true
427 }
428 }
429}
430
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    // A project diagnostics update.
    DiagnosticsUpdate,
    // A change in the buffer with this entity id.
    Buffer(EntityId),
}
436
437impl PredictionRequestedBy {
438 pub fn buffer_id(&self) -> Option<EntityId> {
439 match self {
440 PredictionRequestedBy::DiagnosticsUpdate => None,
441 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
442 }
443 }
444}
445
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier recorded in `ProjectState::cancelled_predictions` when cancelled.
    id: usize,
    // Resolves to the id of the produced prediction, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
454
/// A prediction from the perspective of a buffer.
///
/// NOTE(review): `Local` presumably means the prediction applies in this buffer, and `Jump`
/// that it lies elsewhere and requires navigating there — confirm at the construction site.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
461
// Test-only convenience: both variants wrap a prediction, so deref exposes it directly.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction }
            | BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
473
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    // File backing the buffer when it was last observed.
    file: Option<Arc<dyn File>>,
    // Last snapshot seen; replaced each time changes are reported for the buffer.
    snapshot: TextBufferSnapshot,
    // Last known cursor position in the buffer, if any.
    last_position: Option<Anchor>,
    // Subscription to buffer edits and the observer that unregisters the buffer on release.
    _subscriptions: [gpui::Subscription; 2],
}
480
/// The in-progress (not yet finalized) edit event for a buffer.
#[derive(Clone)]
struct LastEvent {
    // Buffer state at the start and at the end of the event.
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    // Backing files at the start and at the end (they can differ, e.g. after a rename).
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    // Whether the change came from an accepted prediction.
    predicted: bool,
    // Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
492
493impl LastEvent {
494 pub fn finalize(
495 &self,
496 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
497 cx: &App,
498 ) -> Option<StoredEvent> {
499 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
500 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
501
502 let in_open_source_repo =
503 [self.new_file.as_ref(), self.old_file.as_ref()]
504 .iter()
505 .all(|file| {
506 file.is_some_and(|file| {
507 license_detection_watchers
508 .get(&file.worktree_id(cx))
509 .is_some_and(|watcher| watcher.is_project_open_source())
510 })
511 });
512
513 let (diff, edit_range) =
514 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
515
516 if path == old_path && diff.is_empty() {
517 None
518 } else {
519 Some(StoredEvent {
520 event: Arc::new(zeta_prompt::Event::BufferChange {
521 old_path,
522 path,
523 diff,
524 in_open_source_repo,
525 predicted: self.predicted,
526 }),
527 edit_range: self.new_snapshot.anchor_before(edit_range.start)
528 ..self.new_snapshot.anchor_before(edit_range.end),
529 old_snapshot: self.old_snapshot.clone(),
530 })
531 }
532 }
533
534 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
535 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
536 return (self.clone(), None);
537 };
538
539 let before = LastEvent {
540 old_snapshot: self.old_snapshot.clone(),
541 new_snapshot: boundary_snapshot.clone(),
542 old_file: self.old_file.clone(),
543 new_file: self.new_file.clone(),
544 edit_range: None,
545 predicted: self.predicted,
546 snapshot_after_last_editing_pause: None,
547 last_edit_time: self.last_edit_time,
548 };
549
550 let after = LastEvent {
551 old_snapshot: boundary_snapshot.clone(),
552 new_snapshot: self.new_snapshot.clone(),
553 old_file: self.old_file.clone(),
554 new_file: self.new_file.clone(),
555 edit_range: None,
556 predicted: self.predicted,
557 snapshot_after_last_editing_pause: None,
558 last_edit_time: self.last_edit_time,
559 };
560
561 (before, Some(after))
562 }
563}
564
/// Computes a unified diff describing all edits made between `old_snapshot` and
/// `new_snapshot`.
///
/// The diffed region runs from the first edit to the last, padded with surrounding context
/// rows. Returns the diff text and the edited range in `new_snapshot` coordinates (from
/// the start of the first edit to the end of the last). Returns `None` when `edits_since`
/// yields no edits.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<(String, Range<Point>)> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // Bail out when there is nothing to diff.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Leading context: up to CONTEXT_LINES rows before the first edited row.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    // Trailing context: the region below ends at the start of the row after
    // `*_context_end_row`, so up to CONTEXT_LINES + 1 rows after the last edited row are
    // included. NOTE(review): the extra row is asymmetric with the leading context;
    // presumably deliberate (e.g. for edits ending at a line start) — confirm.
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Region offsets, clamped to the end of each buffer.
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // The start rows are passed so hunk headers presumably refer to whole-buffer line
    // numbers rather than region-relative ones.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
610
611fn buffer_path_with_id_fallback(
612 file: Option<&Arc<dyn File>>,
613 snapshot: &TextBufferSnapshot,
614 cx: &App,
615) -> Arc<Path> {
616 if let Some(file) = file {
617 file.full_path(cx).into()
618 } else {
619 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
620 }
621}
622
623impl EditPredictionStore {
624 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
625 cx.try_global::<EditPredictionStoreGlobal>()
626 .map(|global| global.0.clone())
627 }
628
629 pub fn global(
630 client: &Arc<Client>,
631 user_store: &Entity<UserStore>,
632 cx: &mut App,
633 ) -> Entity<Self> {
634 cx.try_global::<EditPredictionStoreGlobal>()
635 .map(|global| global.0.clone())
636 .unwrap_or_else(|| {
637 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
638 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
639 ep_store
640 })
641 }
642
643 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
644 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
645 let data_collection_choice = Self::load_data_collection_choice();
646
647 let llm_token = LlmApiToken::default();
648
649 let (reject_tx, reject_rx) = mpsc::unbounded();
650 cx.background_spawn({
651 let client = client.clone();
652 let llm_token = llm_token.clone();
653 let app_version = AppVersion::global(cx);
654 let background_executor = cx.background_executor().clone();
655 async move {
656 Self::handle_rejected_predictions(
657 reject_rx,
658 client,
659 llm_token,
660 app_version,
661 background_executor,
662 )
663 .await
664 }
665 })
666 .detach();
667
668 let this = Self {
669 projects: HashMap::default(),
670 client,
671 user_store,
672 llm_token,
673 _llm_token_subscription: cx.subscribe(
674 &refresh_llm_token_listener,
675 |this, _listener, _event, cx| {
676 let client = this.client.clone();
677 let llm_token = this.llm_token.clone();
678 cx.spawn(async move |_this, _cx| {
679 llm_token.refresh(&client).await?;
680 anyhow::Ok(())
681 })
682 .detach_and_log_err(cx);
683 },
684 ),
685 update_required: false,
686 edit_prediction_model: EditPredictionModel::Zeta2,
687 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
688 sweep_ai: SweepAi::new(cx),
689 mercury: Mercury::new(cx),
690 ollama: Ollama::new(),
691
692 data_collection_choice,
693 reject_predictions_tx: reject_tx,
694 rated_predictions: Default::default(),
695 shown_predictions: Default::default(),
696 };
697
698 this
699 }
700
701 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
702 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
703 let format = ZetaFormat::parse(&version_str).ok()?;
704 let model_id = env::var("ZED_ZETA_MODEL").ok();
705 Some(Zeta2RawConfig { model_id, format })
706 }
707
    /// Selects the backend used for subsequent edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Sets the raw Zeta2 endpoint configuration, replacing any existing one (including
    /// one read from the environment).
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 endpoint configuration, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
719
720 pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
721 use ui::IconName;
722 match self.edit_prediction_model {
723 EditPredictionModel::Sweep => {
724 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
725 .with_disabled(IconName::SweepAiDisabled)
726 .with_up(IconName::SweepAiUp)
727 .with_down(IconName::SweepAiDown)
728 .with_error(IconName::SweepAiError)
729 }
730 EditPredictionModel::Mercury => {
731 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
732 }
733 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
734 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
735 .with_disabled(IconName::ZedPredictDisabled)
736 .with_up(IconName::ZedPredictUp)
737 .with_down(IconName::ZedPredictDown)
738 .with_error(IconName::ZedPredictError)
739 }
740 EditPredictionModel::Ollama => {
741 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
742 }
743 }
744 }
745
    /// Returns whether an API key is available for Sweep AI.
    pub fn has_sweep_api_token(&self, cx: &App) -> bool {
        self.sweep_ai.api_token.read(cx).has_key()
    }

    /// Returns whether an API key is available for Mercury.
    pub fn has_mercury_api_token(&self, cx: &App) -> bool {
        self.mercury.api_token.read(cx).has_key()
    }
753
754 pub fn clear_history(&mut self) {
755 for project_state in self.projects.values_mut() {
756 project_state.events.clear();
757 project_state.last_event.take();
758 }
759 }
760
761 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
762 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
763 project_state.events.clear();
764 project_state.last_event.take();
765 }
766 }
767
768 pub fn edit_history_for_project(
769 &self,
770 project: &Entity<Project>,
771 cx: &App,
772 ) -> Vec<StoredEvent> {
773 self.projects
774 .get(&project.entity_id())
775 .map(|project_state| project_state.events(cx))
776 .unwrap_or_default()
777 }
778
779 pub fn context_for_project<'a>(
780 &'a self,
781 project: &Entity<Project>,
782 cx: &'a mut App,
783 ) -> Vec<RelatedFile> {
784 self.projects
785 .get(&project.entity_id())
786 .map(|project_state| {
787 project_state.context.update(cx, |context, cx| {
788 context
789 .related_files_with_buffers(cx)
790 .map(|(mut related_file, buffer)| {
791 related_file.in_open_source_repo = buffer
792 .read(cx)
793 .file()
794 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
795 related_file
796 })
797 .collect()
798 })
799 })
800 .unwrap_or_default()
801 }
802
803 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
804 self.projects
805 .get(&project.entity_id())
806 .and_then(|project| project.copilot.clone())
807 }
808
809 pub fn start_copilot_for_project(
810 &mut self,
811 project: &Entity<Project>,
812 cx: &mut Context<Self>,
813 ) -> Option<Entity<Copilot>> {
814 if DisableAiSettings::get(None, cx).disable_ai {
815 return None;
816 }
817 let state = self.get_or_init_project(project, cx);
818
819 if state.copilot.is_some() {
820 return state.copilot.clone();
821 }
822 let _project = project.clone();
823 let project = project.read(cx);
824
825 let node = project.node_runtime().cloned();
826 if let Some(node) = node {
827 let next_id = project.languages().next_language_server_id();
828 let fs = project.fs().clone();
829
830 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
831 state.copilot = Some(copilot.clone());
832 Some(copilot)
833 } else {
834 None
835 }
836 }
837
838 pub fn context_for_project_with_buffers<'a>(
839 &'a self,
840 project: &Entity<Project>,
841 cx: &'a mut App,
842 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
843 self.projects
844 .get(&project.entity_id())
845 .map(|project| {
846 project.context.update(cx, |context, cx| {
847 context.related_files_with_buffers(cx).collect()
848 })
849 })
850 .unwrap_or_default()
851 }
852
853 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
854 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
855 self.user_store.read(cx).edit_prediction_usage()
856 } else {
857 None
858 }
859 }
860
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, initializing the project's state if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        Self::register_buffer_impl(project_state, buffer, project, cx);
    }
874
    /// Returns the state for `project`, creating it on first use.
    ///
    /// Initialization creates the project's `RelatedExcerptStore` and forwards its events to
    /// `handle_excerpt_store_event`. It also subscribes to project events and removes the
    /// state again when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than stored in `_subscriptions`. Presumably it becomes
                    // inert once the excerpt store is dropped — confirm.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project itself is released.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
914
    /// Drops all state tracked for `project`. Does nothing if the project is not registered.
    pub fn remove_project(&mut self, project: &Entity<Project>) {
        self.projects.remove(&project.entity_id());
    }
918
919 fn handle_excerpt_store_event(
920 &mut self,
921 project_entity_id: EntityId,
922 event: &RelatedExcerptStoreEvent,
923 ) {
924 if let Some(project_state) = self.projects.get(&project_entity_id) {
925 if let Some(debug_tx) = project_state.debug_tx.clone() {
926 match event {
927 RelatedExcerptStoreEvent::StartedRefresh => {
928 debug_tx
929 .unbounded_send(DebugEvent::ContextRetrievalStarted(
930 ContextRetrievalStartedDebugEvent {
931 project_entity_id: project_entity_id,
932 timestamp: Instant::now(),
933 search_prompt: String::new(),
934 },
935 ))
936 .ok();
937 }
938 RelatedExcerptStoreEvent::FinishedRefresh {
939 cache_hit_count,
940 cache_miss_count,
941 mean_definition_latency,
942 max_definition_latency,
943 } => {
944 debug_tx
945 .unbounded_send(DebugEvent::ContextRetrievalFinished(
946 ContextRetrievalFinishedDebugEvent {
947 project_entity_id: project_entity_id,
948 timestamp: Instant::now(),
949 metadata: vec![
950 (
951 "Cache Hits",
952 format!(
953 "{}/{}",
954 cache_hit_count,
955 cache_hit_count + cache_miss_count
956 )
957 .into(),
958 ),
959 (
960 "Max LSP Time",
961 format!("{} ms", max_definition_latency.as_millis())
962 .into(),
963 ),
964 (
965 "Mean LSP Time",
966 format!("{} ms", mean_definition_latency.as_millis())
967 .into(),
968 ),
969 ],
970 },
971 ))
972 .ok();
973 }
974 }
975 }
976 }
977 }
978
979 pub fn debug_info(
980 &mut self,
981 project: &Entity<Project>,
982 cx: &mut Context<Self>,
983 ) -> mpsc::UnboundedReceiver<DebugEvent> {
984 let project_state = self.get_or_init_project(project, cx);
985 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
986 project_state.debug_tx = Some(debug_watch_tx);
987 debug_watch_rx
988 }
989
990 fn handle_project_event(
991 &mut self,
992 project: Entity<Project>,
993 event: &project::Event,
994 cx: &mut Context<Self>,
995 ) {
996 // TODO [zeta2] init with recent paths
997 match event {
998 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
999 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1000 return;
1001 };
1002 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1003 if let Some(path) = path {
1004 if let Some(ix) = project_state
1005 .recent_paths
1006 .iter()
1007 .position(|probe| probe == &path)
1008 {
1009 project_state.recent_paths.remove(ix);
1010 }
1011 project_state.recent_paths.push_front(path);
1012 }
1013 }
1014 project::Event::DiagnosticsUpdated { .. } => {
1015 if cx.has_flag::<Zeta2FeatureFlag>() {
1016 self.refresh_prediction_from_diagnostics(project, cx);
1017 }
1018 }
1019 _ => (),
1020 }
1021 }
1022
    /// Registers `buffer` with `project_state` (if it is not registered yet) and returns its
    /// tracking entry.
    ///
    /// As a side effect, this ensures a license-detection watcher exists for the buffer's
    /// worktree. The watcher is removed when the worktree is released. New registrations
    /// subscribe to buffer edits, which are reported via `report_changes_for_buffer`, and
    /// unregister the buffer when it is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Report edits as (non-predicted) changes while the project is alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Unregister the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1088
    /// Records the changes made to `buffer` since it was last observed.
    ///
    /// Registers the buffer if needed (via `register_buffer_impl`) and compares its
    /// current text snapshot with the stored one. Nothing happens if the version is
    /// unchanged. Otherwise the stored file and snapshot are replaced, and the edits in
    /// between are summarized:
    /// - A user action (char/selection insert or delete) is recorded at the end
    ///   offset of the last edit, timestamped in Unix milliseconds.
    /// - The edits are either coalesced into the project's pending `last_event`, or
    ///   that pending event is finalized into the `events` history and a new
    ///   `last_event` is started. Coalescing requires the same buffer continuing
    ///   from the last event's snapshot, the same `is_predicted` flag, and the edit
    ///   lying within `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit range.
    ///
    /// `is_predicted` is stored on the event. In this file, `accept_current_prediction`
    /// passes `true` and buffer `Edited` events pass `false`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Same version as last time: there is nothing new to report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Swap in the new state; the old one is what the edits are diffed against.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Summarize the edits: count them, sum the lengths of the removed/inserted
        // (usize offset) ranges, and span an anchor range from the start of the first
        // edit to the end of the last one.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        // The version changed but no edits were reported: keep the updated snapshot only.
        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits whose total inserted (or deleted) length equals the number of edits are
        // treated as single "char" actions (e.g. typing with multiple cursors); anything
        // larger counts as a "selection" action.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // A clock before the Unix epoch is recorded as 0.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The new edits continue exactly where the pending event left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a long enough pause, remember the state the user paused at
                // before extending the event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalescable: move the pending event (if it finalizes) into the history,
        // keeping `events` below `EVENT_COUNT_MAX` entries by dropping the oldest.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1216
1217 fn prediction_at(
1218 &mut self,
1219 buffer: &Entity<Buffer>,
1220 position: Option<language::Anchor>,
1221 project: &Entity<Project>,
1222 cx: &App,
1223 ) -> Option<BufferEditPrediction<'_>> {
1224 let project_state = self.projects.get_mut(&project.entity_id())?;
1225 if let Some(position) = position
1226 && let Some(buffer) = project_state
1227 .registered_buffers
1228 .get_mut(&buffer.entity_id())
1229 {
1230 buffer.last_position = Some(position);
1231 }
1232
1233 let CurrentEditPrediction {
1234 requested_by,
1235 prediction,
1236 ..
1237 } = project_state.current_prediction.as_ref()?;
1238
1239 if prediction.targets_buffer(buffer.read(cx)) {
1240 Some(BufferEditPrediction::Local { prediction })
1241 } else {
1242 let show_jump = match requested_by {
1243 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1244 requested_by_buffer_id == &buffer.entity_id()
1245 }
1246 PredictionRequestedBy::DiagnosticsUpdate => true,
1247 };
1248
1249 if show_jump {
1250 Some(BufferEditPrediction::Jump { prediction })
1251 } else {
1252 None
1253 }
1254 }
1255 }
1256
1257 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1258 let Some(current_prediction) = self
1259 .projects
1260 .get_mut(&project.entity_id())
1261 .and_then(|project_state| project_state.current_prediction.take())
1262 else {
1263 return;
1264 };
1265
1266 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1267
1268 // can't hold &mut project_state ref across report_changes_for_buffer_call
1269 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1270 return;
1271 };
1272
1273 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1274 project_state.cancel_pending_prediction(pending_prediction, cx);
1275 }
1276
1277 match self.edit_prediction_model {
1278 EditPredictionModel::Sweep => {
1279 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1280 }
1281 EditPredictionModel::Mercury => {
1282 mercury::edit_prediction_accepted(
1283 current_prediction.prediction.id,
1284 self.client.http_client(),
1285 cx,
1286 );
1287 }
1288 EditPredictionModel::Ollama => {}
1289 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1290 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1291 }
1292 }
1293 }
1294
    /// Background loop that batches rejected predictions and reports them to the
    /// `/predict_edits/reject` endpoint.
    ///
    /// Each received rejection is appended to a local batch. While the batch holds
    /// fewer than `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2` items, the loop
    /// waits for whichever comes first, biased toward the channel:
    /// - another rejection, which is then added to the batch;
    /// - the `REJECT_REQUEST_DEBOUNCE` timer.
    ///
    /// It then sends at most `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of the most
    /// recent items in one request. Sent items are removed only when the request
    /// succeeds. On failure the error is logged and the items stay in the batch for a
    /// later flush. The loop ends once the channel is closed and drained.
    ///
    /// NOTE(review): If requests keep failing, `batched` grows without bound. The
    /// endpoint URL is also rebuilt (and `unwrap`ped) on every flush.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Debounce small batches. `peek` yielding `Some` means another rejection is
            // ready, so loop around to collect it. `peek` yielding `None` (channel
            // closed) or the timer firing falls through to the flush below.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Send the newest items, capped at the per-request maximum.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // Only forget the sent items once the server accepted them.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1353
1354 fn reject_current_prediction(
1355 &mut self,
1356 reason: EditPredictionRejectReason,
1357 project: &Entity<Project>,
1358 cx: &App,
1359 ) {
1360 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1361 project_state.pending_predictions.clear();
1362 if let Some(prediction) = project_state.current_prediction.take() {
1363 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1364 }
1365 };
1366 }
1367
1368 fn did_show_current_prediction(
1369 &mut self,
1370 project: &Entity<Project>,
1371 display_type: edit_prediction_types::SuggestionDisplayType,
1372 cx: &mut Context<Self>,
1373 ) {
1374 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1375 return;
1376 };
1377
1378 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1379 return;
1380 };
1381
1382 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1383 let previous_shown_with = current_prediction.shown_with;
1384
1385 if previous_shown_with.is_none() || !is_jump {
1386 current_prediction.shown_with = Some(display_type);
1387 }
1388
1389 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1390
1391 if is_first_non_jump_show {
1392 current_prediction.was_shown = true;
1393 }
1394
1395 let display_type_changed = previous_shown_with != Some(display_type);
1396
1397 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1398 sweep_ai::edit_prediction_shown(
1399 &self.sweep_ai,
1400 self.client.clone(),
1401 ¤t_prediction.prediction,
1402 display_type,
1403 cx,
1404 );
1405 }
1406
1407 if is_first_non_jump_show {
1408 self.shown_predictions
1409 .push_front(current_prediction.prediction.clone());
1410 if self.shown_predictions.len() > 50 {
1411 let completion = self.shown_predictions.pop_back().unwrap();
1412 self.rated_predictions.remove(&completion.id);
1413 }
1414 }
1415 }
1416
1417 fn reject_prediction(
1418 &mut self,
1419 prediction_id: EditPredictionId,
1420 reason: EditPredictionRejectReason,
1421 was_shown: bool,
1422 cx: &App,
1423 ) {
1424 match self.edit_prediction_model {
1425 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1426 self.reject_predictions_tx
1427 .unbounded_send(EditPredictionRejection {
1428 request_id: prediction_id.to_string(),
1429 reason,
1430 was_shown,
1431 })
1432 .log_err();
1433 }
1434 EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1435 EditPredictionModel::Mercury => {
1436 mercury::edit_prediction_rejected(
1437 prediction_id,
1438 was_shown,
1439 reason,
1440 self.client.http_client(),
1441 cx,
1442 );
1443 }
1444 }
1445 }
1446
1447 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1448 self.projects
1449 .get(&project.entity_id())
1450 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1451 }
1452
1453 pub fn refresh_prediction_from_buffer(
1454 &mut self,
1455 project: Entity<Project>,
1456 buffer: Entity<Buffer>,
1457 position: language::Anchor,
1458 cx: &mut Context<Self>,
1459 ) {
1460 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1461 let Some(request_task) = this
1462 .update(cx, |this, cx| {
1463 this.request_prediction(
1464 &project,
1465 &buffer,
1466 position,
1467 PredictEditsRequestTrigger::Other,
1468 cx,
1469 )
1470 })
1471 .log_err()
1472 else {
1473 return Task::ready(anyhow::Ok(None));
1474 };
1475
1476 cx.spawn(async move |_cx| {
1477 request_task.await.map(|prediction_result| {
1478 prediction_result.map(|prediction_result| {
1479 (
1480 prediction_result,
1481 PredictionRequestedBy::Buffer(buffer.entity_id()),
1482 )
1483 })
1484 })
1485 })
1486 })
1487 }
1488
1489 pub fn refresh_prediction_from_diagnostics(
1490 &mut self,
1491 project: Entity<Project>,
1492 cx: &mut Context<Self>,
1493 ) {
1494 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1495 return;
1496 };
1497
1498 // Prefer predictions from buffer
1499 if project_state.current_prediction.is_some() {
1500 return;
1501 };
1502
1503 self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
1504 let Some((active_buffer, snapshot, cursor_point)) = this
1505 .read_with(cx, |this, cx| {
1506 let project_state = this.projects.get(&project.entity_id())?;
1507 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1508 let snapshot = buffer.read(cx).snapshot();
1509
1510 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1511 return None;
1512 }
1513
1514 let cursor_point = position
1515 .map(|pos| pos.to_point(&snapshot))
1516 .unwrap_or_default();
1517
1518 Some((buffer, snapshot, cursor_point))
1519 })
1520 .log_err()
1521 .flatten()
1522 else {
1523 return Task::ready(anyhow::Ok(None));
1524 };
1525
1526 cx.spawn(async move |cx| {
1527 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1528 active_buffer,
1529 &snapshot,
1530 Default::default(),
1531 cursor_point,
1532 &project,
1533 cx,
1534 )
1535 .await?
1536 else {
1537 return anyhow::Ok(None);
1538 };
1539
1540 let Some(prediction_result) = this
1541 .update(cx, |this, cx| {
1542 this.request_prediction(
1543 &project,
1544 &jump_buffer,
1545 jump_position,
1546 PredictEditsRequestTrigger::Diagnostics,
1547 cx,
1548 )
1549 })?
1550 .await?
1551 else {
1552 return anyhow::Ok(None);
1553 };
1554
1555 this.update(cx, |this, cx| {
1556 Some((
1557 if this
1558 .get_or_init_project(&project, cx)
1559 .current_prediction
1560 .is_none()
1561 {
1562 prediction_result
1563 } else {
1564 EditPredictionResult {
1565 id: prediction_result.id,
1566 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1567 }
1568 },
1569 PredictionRequestedBy::DiagnosticsUpdate,
1570 ))
1571 })
1572 })
1573 });
1574 }
1575
1576 fn predictions_enabled_at(
1577 snapshot: &BufferSnapshot,
1578 position: Option<language::Anchor>,
1579 cx: &App,
1580 ) -> bool {
1581 let file = snapshot.file();
1582 let all_settings = all_language_settings(file, cx);
1583 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1584 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1585 {
1586 return false;
1587 }
1588
1589 if let Some(last_position) = position {
1590 let settings = snapshot.settings_at(last_position, cx);
1591
1592 if !settings.edit_predictions_disabled_in.is_empty()
1593 && let Some(scope) = snapshot.language_scope_at(last_position)
1594 && let Some(scope_name) = scope.override_name()
1595 && settings
1596 .edit_predictions_disabled_in
1597 .iter()
1598 .any(|s| s == scope_name)
1599 {
1600 return false;
1601 }
1602 }
1603
1604 true
1605 }
1606
1607 #[cfg(not(test))]
1608 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1609 #[cfg(test)]
1610 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1611
1612 fn queue_prediction_refresh(
1613 &mut self,
1614 project: Entity<Project>,
1615 throttle_entity: EntityId,
1616 cx: &mut Context<Self>,
1617 do_refresh: impl FnOnce(
1618 WeakEntity<Self>,
1619 &mut AsyncApp,
1620 )
1621 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1622 + 'static,
1623 ) {
1624 let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1625 let drop_on_cancel = is_ollama;
1626 let max_pending_predictions = if is_ollama { 1 } else { 2 };
1627 let project_state = self.get_or_init_project(&project, cx);
1628 let pending_prediction_id = project_state.next_pending_prediction_id;
1629 project_state.next_pending_prediction_id += 1;
1630 let last_request = project_state.last_prediction_refresh;
1631
1632 let task = cx.spawn(async move |this, cx| {
1633 if let Some((last_entity, last_timestamp)) = last_request
1634 && throttle_entity == last_entity
1635 && let Some(timeout) =
1636 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1637 {
1638 cx.background_executor().timer(timeout).await;
1639 }
1640
1641 // If this task was cancelled before the throttle timeout expired,
1642 // do not perform a request.
1643 let mut is_cancelled = true;
1644 this.update(cx, |this, cx| {
1645 let project_state = this.get_or_init_project(&project, cx);
1646 if !project_state
1647 .cancelled_predictions
1648 .remove(&pending_prediction_id)
1649 {
1650 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1651 is_cancelled = false;
1652 }
1653 })
1654 .ok();
1655 if is_cancelled {
1656 return None;
1657 }
1658
1659 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1660 let new_prediction_id = new_prediction_result
1661 .as_ref()
1662 .map(|(prediction, _)| prediction.id.clone());
1663
1664 // When a prediction completes, remove it from the pending list, and cancel
1665 // any pending predictions that were enqueued before it.
1666 this.update(cx, |this, cx| {
1667 let project_state = this.get_or_init_project(&project, cx);
1668
1669 let is_cancelled = project_state
1670 .cancelled_predictions
1671 .remove(&pending_prediction_id);
1672
1673 let new_current_prediction = if !is_cancelled
1674 && let Some((prediction_result, requested_by)) = new_prediction_result
1675 {
1676 match prediction_result.prediction {
1677 Ok(prediction) => {
1678 let new_prediction = CurrentEditPrediction {
1679 requested_by,
1680 prediction,
1681 was_shown: false,
1682 shown_with: None,
1683 };
1684
1685 if let Some(current_prediction) =
1686 project_state.current_prediction.as_ref()
1687 {
1688 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1689 {
1690 this.reject_current_prediction(
1691 EditPredictionRejectReason::Replaced,
1692 &project,
1693 cx,
1694 );
1695
1696 Some(new_prediction)
1697 } else {
1698 this.reject_prediction(
1699 new_prediction.prediction.id,
1700 EditPredictionRejectReason::CurrentPreferred,
1701 false,
1702 cx,
1703 );
1704 None
1705 }
1706 } else {
1707 Some(new_prediction)
1708 }
1709 }
1710 Err(reject_reason) => {
1711 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1712 None
1713 }
1714 }
1715 } else {
1716 None
1717 };
1718
1719 let project_state = this.get_or_init_project(&project, cx);
1720
1721 if let Some(new_prediction) = new_current_prediction {
1722 project_state.current_prediction = Some(new_prediction);
1723 }
1724
1725 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1726 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1727 if pending_prediction.id == pending_prediction_id {
1728 pending_predictions.remove(ix);
1729 for pending_prediction in pending_predictions.drain(0..ix) {
1730 project_state.cancel_pending_prediction(pending_prediction, cx)
1731 }
1732 break;
1733 }
1734 }
1735 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1736 cx.notify();
1737 })
1738 .ok();
1739
1740 new_prediction_id
1741 });
1742
1743 if project_state.pending_predictions.len() < max_pending_predictions {
1744 project_state.pending_predictions.push(PendingPrediction {
1745 id: pending_prediction_id,
1746 task,
1747 drop_on_cancel,
1748 });
1749 } else {
1750 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1751 project_state.pending_predictions.push(PendingPrediction {
1752 id: pending_prediction_id,
1753 task,
1754 drop_on_cancel,
1755 });
1756 project_state.cancel_pending_prediction(pending_prediction, cx);
1757 }
1758 }
1759
1760 pub fn request_prediction(
1761 &mut self,
1762 project: &Entity<Project>,
1763 active_buffer: &Entity<Buffer>,
1764 position: language::Anchor,
1765 trigger: PredictEditsRequestTrigger,
1766 cx: &mut Context<Self>,
1767 ) -> Task<Result<Option<EditPredictionResult>>> {
1768 self.request_prediction_internal(
1769 project.clone(),
1770 active_buffer.clone(),
1771 position,
1772 trigger,
1773 cx.has_flag::<Zeta2FeatureFlag>(),
1774 cx,
1775 )
1776 }
1777
1778 fn request_prediction_internal(
1779 &mut self,
1780 project: Entity<Project>,
1781 active_buffer: Entity<Buffer>,
1782 position: language::Anchor,
1783 trigger: PredictEditsRequestTrigger,
1784 allow_jump: bool,
1785 cx: &mut Context<Self>,
1786 ) -> Task<Result<Option<EditPredictionResult>>> {
1787 const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1788
1789 self.get_or_init_project(&project, cx);
1790 let project_state = self.projects.get(&project.entity_id()).unwrap();
1791 let stored_events = project_state.events(cx);
1792 let has_events = !stored_events.is_empty();
1793 let events: Vec<Arc<zeta_prompt::Event>> =
1794 stored_events.into_iter().map(|e| e.event).collect();
1795 let debug_tx = project_state.debug_tx.clone();
1796
1797 let snapshot = active_buffer.read(cx).snapshot();
1798 let cursor_point = position.to_point(&snapshot);
1799 let current_offset = position.to_offset(&snapshot);
1800
1801 let mut user_actions: Vec<UserActionRecord> =
1802 project_state.user_actions.iter().cloned().collect();
1803
1804 if let Some(last_action) = user_actions.last() {
1805 if last_action.buffer_id == active_buffer.entity_id()
1806 && current_offset != last_action.offset
1807 {
1808 let timestamp_epoch_ms = SystemTime::now()
1809 .duration_since(UNIX_EPOCH)
1810 .map(|d| d.as_millis() as u64)
1811 .unwrap_or(0);
1812 user_actions.push(UserActionRecord {
1813 action_type: UserActionType::CursorMovement,
1814 buffer_id: active_buffer.entity_id(),
1815 line_number: cursor_point.row,
1816 offset: current_offset,
1817 timestamp_epoch_ms,
1818 });
1819 }
1820 }
1821 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1822 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1823 let diagnostic_search_range =
1824 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1825
1826 let related_files = self.context_for_project(&project, cx);
1827
1828 let inputs = EditPredictionModelInput {
1829 project: project.clone(),
1830 buffer: active_buffer.clone(),
1831 snapshot: snapshot.clone(),
1832 position,
1833 events,
1834 related_files,
1835 recent_paths: project_state.recent_paths.clone(),
1836 trigger,
1837 diagnostic_search_range: diagnostic_search_range.clone(),
1838 debug_tx,
1839 user_actions,
1840 };
1841
1842 let task = match &self.edit_prediction_model {
1843 EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
1844 self,
1845 inputs,
1846 Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1847 cx,
1848 ),
1849 EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
1850 self,
1851 inputs,
1852 Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1853 cx,
1854 ),
1855 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1856 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1857 EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
1858 };
1859
1860 cx.spawn(async move |this, cx| {
1861 let prediction = task.await?;
1862
1863 if prediction.is_none() && allow_jump {
1864 let cursor_point = position.to_point(&snapshot);
1865 if has_events
1866 && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1867 active_buffer.clone(),
1868 &snapshot,
1869 diagnostic_search_range,
1870 cursor_point,
1871 &project,
1872 cx,
1873 )
1874 .await?
1875 {
1876 return this
1877 .update(cx, |this, cx| {
1878 this.request_prediction_internal(
1879 project,
1880 jump_buffer,
1881 jump_position,
1882 trigger,
1883 false,
1884 cx,
1885 )
1886 })?
1887 .await;
1888 }
1889
1890 return anyhow::Ok(None);
1891 }
1892
1893 Ok(prediction)
1894 })
1895 }
1896
1897 async fn next_diagnostic_location(
1898 active_buffer: Entity<Buffer>,
1899 active_buffer_snapshot: &BufferSnapshot,
1900 active_buffer_diagnostic_search_range: Range<Point>,
1901 active_buffer_cursor_point: Point,
1902 project: &Entity<Project>,
1903 cx: &mut AsyncApp,
1904 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1905 // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
1906 let mut jump_location = active_buffer_snapshot
1907 .diagnostic_groups(None)
1908 .into_iter()
1909 .filter_map(|(_, group)| {
1910 let range = &group.entries[group.primary_ix]
1911 .range
1912 .to_point(&active_buffer_snapshot);
1913 if range.overlaps(&active_buffer_diagnostic_search_range) {
1914 None
1915 } else {
1916 Some(range.start)
1917 }
1918 })
1919 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
1920 .map(|position| {
1921 (
1922 active_buffer.clone(),
1923 active_buffer_snapshot.anchor_before(position),
1924 )
1925 });
1926
1927 if jump_location.is_none() {
1928 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
1929 let file = buffer.file()?;
1930
1931 Some(ProjectPath {
1932 worktree_id: file.worktree_id(cx),
1933 path: file.path().clone(),
1934 })
1935 });
1936
1937 let buffer_task = project.update(cx, |project, cx| {
1938 let (path, _, _) = project
1939 .diagnostic_summaries(false, cx)
1940 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
1941 .max_by_key(|(path, _, _)| {
1942 // find the buffer with errors that shares most parent directories
1943 path.path
1944 .components()
1945 .zip(
1946 active_buffer_path
1947 .as_ref()
1948 .map(|p| p.path.components())
1949 .unwrap_or_default(),
1950 )
1951 .take_while(|(a, b)| a == b)
1952 .count()
1953 })?;
1954
1955 Some(project.open_buffer(path, cx))
1956 });
1957
1958 if let Some(buffer_task) = buffer_task {
1959 let closest_buffer = buffer_task.await?;
1960
1961 jump_location = closest_buffer
1962 .read_with(cx, |buffer, _cx| {
1963 buffer
1964 .buffer_diagnostics(None)
1965 .into_iter()
1966 .min_by_key(|entry| entry.diagnostic.severity)
1967 .map(|entry| entry.range.start)
1968 })
1969 .map(|position| (closest_buffer, position));
1970 }
1971 }
1972
1973 anyhow::Ok(jump_location)
1974 }
1975
1976 async fn send_raw_llm_request(
1977 request: RawCompletionRequest,
1978 client: Arc<Client>,
1979 custom_url: Option<Arc<Url>>,
1980 llm_token: LlmApiToken,
1981 app_version: Version,
1982 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1983 let url = if let Some(custom_url) = custom_url {
1984 custom_url.as_ref().clone()
1985 } else {
1986 client
1987 .http_client()
1988 .build_zed_llm_url("/predict_edits/raw", &[])?
1989 };
1990
1991 Self::send_api_request(
1992 |builder| {
1993 let req = builder
1994 .uri(url.as_ref())
1995 .body(serde_json::to_string(&request)?.into());
1996 Ok(req?)
1997 },
1998 client,
1999 llm_token,
2000 app_version,
2001 true,
2002 )
2003 .await
2004 }
2005
2006 pub(crate) async fn send_v3_request(
2007 input: ZetaPromptInput,
2008 client: Arc<Client>,
2009 llm_token: LlmApiToken,
2010 app_version: Version,
2011 trigger: PredictEditsRequestTrigger,
2012 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2013 let url = client
2014 .http_client()
2015 .build_zed_llm_url("/predict_edits/v3", &[])?;
2016
2017 let request = PredictEditsV3Request { input, trigger };
2018
2019 let json_bytes = serde_json::to_vec(&request)?;
2020 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2021
2022 Self::send_api_request(
2023 |builder| {
2024 let req = builder
2025 .uri(url.as_ref())
2026 .header("Content-Encoding", "zstd")
2027 .body(compressed.clone().into());
2028 Ok(req?)
2029 },
2030 client,
2031 llm_token,
2032 app_version,
2033 true,
2034 )
2035 .await
2036 }
2037
2038 fn handle_api_response<T>(
2039 this: &WeakEntity<Self>,
2040 response: Result<(T, Option<EditPredictionUsage>)>,
2041 cx: &mut gpui::AsyncApp,
2042 ) -> Result<T> {
2043 match response {
2044 Ok((data, usage)) => {
2045 if let Some(usage) = usage {
2046 this.update(cx, |this, cx| {
2047 this.user_store.update(cx, |user_store, cx| {
2048 user_store.update_edit_prediction_usage(usage, cx);
2049 });
2050 })
2051 .ok();
2052 }
2053 Ok(data)
2054 }
2055 Err(err) => {
2056 if err.is::<ZedUpdateRequiredError>() {
2057 cx.update(|cx| {
2058 this.update(cx, |this, _cx| {
2059 this.update_required = true;
2060 })
2061 .ok();
2062
2063 let error_message: SharedString = err.to_string().into();
2064 show_app_notification(
2065 NotificationId::unique::<ZedUpdateRequiredError>(),
2066 cx,
2067 move |cx| {
2068 cx.new(|cx| {
2069 ErrorMessagePrompt::new(error_message.clone(), cx)
2070 .with_link_button("Update Zed", "https://zed.dev/releases")
2071 })
2072 },
2073 );
2074 });
2075 }
2076 Err(err)
2077 }
2078 }
2079 }
2080
2081 async fn send_api_request<Res>(
2082 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2083 client: Arc<Client>,
2084 llm_token: LlmApiToken,
2085 app_version: Version,
2086 require_auth: bool,
2087 ) -> Result<(Res, Option<EditPredictionUsage>)>
2088 where
2089 Res: DeserializeOwned,
2090 {
2091 let http_client = client.http_client();
2092
2093 let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2094 Some(custom_token)
2095 } else if require_auth {
2096 Some(llm_token.acquire(&client).await?)
2097 } else {
2098 llm_token.acquire(&client).await.ok()
2099 };
2100 let mut did_retry = false;
2101
2102 loop {
2103 let request_builder = http_client::Request::builder().method(Method::POST);
2104
2105 let mut request_builder = request_builder
2106 .header("Content-Type", "application/json")
2107 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2108
2109 // Only add Authorization header if we have a token
2110 if let Some(ref token_value) = token {
2111 request_builder =
2112 request_builder.header("Authorization", format!("Bearer {}", token_value));
2113 }
2114
2115 let request = build(request_builder)?;
2116
2117 let mut response = http_client.send(request).await?;
2118
2119 if let Some(minimum_required_version) = response
2120 .headers()
2121 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2122 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2123 {
2124 anyhow::ensure!(
2125 app_version >= minimum_required_version,
2126 ZedUpdateRequiredError {
2127 minimum_version: minimum_required_version
2128 }
2129 );
2130 }
2131
2132 if response.status().is_success() {
2133 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2134
2135 let mut body = Vec::new();
2136 response.body_mut().read_to_end(&mut body).await?;
2137 return Ok((serde_json::from_slice(&body)?, usage));
2138 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2139 did_retry = true;
2140 token = Some(llm_token.refresh(&client).await?);
2141 } else {
2142 let mut body = String::new();
2143 response.body_mut().read_to_string(&mut body).await?;
2144 anyhow::bail!(
2145 "Request failed with status: {:?}\nBody: {}",
2146 response.status(),
2147 body
2148 );
2149 }
2150 }
2151 }
2152
2153 pub fn refresh_context(
2154 &mut self,
2155 project: &Entity<Project>,
2156 buffer: &Entity<language::Buffer>,
2157 cursor_position: language::Anchor,
2158 cx: &mut Context<Self>,
2159 ) {
2160 self.get_or_init_project(project, cx)
2161 .context
2162 .update(cx, |store, cx| {
2163 store.refresh(buffer.clone(), cursor_position, cx);
2164 });
2165 }
2166
2167 #[cfg(feature = "cli-support")]
2168 pub fn set_context_for_buffer(
2169 &mut self,
2170 project: &Entity<Project>,
2171 related_files: Vec<RelatedFile>,
2172 cx: &mut Context<Self>,
2173 ) {
2174 self.get_or_init_project(project, cx)
2175 .context
2176 .update(cx, |store, cx| {
2177 store.set_related_files(related_files, cx);
2178 });
2179 }
2180
2181 #[cfg(feature = "cli-support")]
2182 pub fn set_recent_paths_for_project(
2183 &mut self,
2184 project: &Entity<Project>,
2185 paths: impl IntoIterator<Item = project::ProjectPath>,
2186 cx: &mut Context<Self>,
2187 ) {
2188 let project_state = self.get_or_init_project(project, cx);
2189 project_state.recent_paths = paths.into_iter().collect();
2190 }
2191
2192 fn is_file_open_source(
2193 &self,
2194 project: &Entity<Project>,
2195 file: &Arc<dyn File>,
2196 cx: &App,
2197 ) -> bool {
2198 if !file.is_local() || file.is_private() {
2199 return false;
2200 }
2201 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2202 return false;
2203 };
2204 project_state
2205 .license_detection_watchers
2206 .get(&file.worktree_id(cx))
2207 .as_ref()
2208 .is_some_and(|watcher| watcher.is_project_open_source())
2209 }
2210
2211 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2212 self.data_collection_choice.is_enabled(cx)
2213 }
2214
2215 fn load_data_collection_choice() -> DataCollectionChoice {
2216 let choice = KEY_VALUE_STORE
2217 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2218 .log_err()
2219 .flatten();
2220
2221 match choice.as_deref() {
2222 Some("true") => DataCollectionChoice::Enabled,
2223 Some("false") => DataCollectionChoice::Disabled,
2224 Some(_) => {
2225 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2226 DataCollectionChoice::NotAnswered
2227 }
2228 None => DataCollectionChoice::NotAnswered,
2229 }
2230 }
2231
2232 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2233 self.data_collection_choice = self.data_collection_choice.toggle();
2234 let new_choice = self.data_collection_choice;
2235 let is_enabled = new_choice.is_enabled(cx);
2236 db::write_and_log(cx, move || {
2237 KEY_VALUE_STORE.write_kvp(
2238 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2239 is_enabled.to_string(),
2240 )
2241 });
2242 }
2243
2244 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2245 self.shown_predictions.iter()
2246 }
2247
2248 pub fn shown_completions_len(&self) -> usize {
2249 self.shown_predictions.len()
2250 }
2251
2252 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2253 self.rated_predictions.contains(id)
2254 }
2255
2256 pub fn rate_prediction(
2257 &mut self,
2258 prediction: &EditPrediction,
2259 rating: EditPredictionRating,
2260 feedback: String,
2261 cx: &mut Context<Self>,
2262 ) {
2263 self.rated_predictions.insert(prediction.id.clone());
2264 telemetry::event!(
2265 "Edit Prediction Rated",
2266 request_id = prediction.id.to_string(),
2267 rating,
2268 inputs = prediction.inputs,
2269 output = prediction
2270 .edit_preview
2271 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2272 feedback
2273 );
2274 self.client.telemetry().flush_events().detach();
2275 cx.notify();
2276 }
2277}
2278
/// Collapses the run of most-recent mergeable events in `events` into a single
/// `BufferChange` event.
///
/// The merged event's diff spans from the oldest merged event's `old_snapshot` to
/// `end_snapshot`. Whether two adjacent events may merge is decided by
/// `StoredEvent::can_merge` (given `latest_snapshot` and `latest_edit_range`),
/// applied pairwise while walking backwards from the newest event.
///
/// Leaves `events` untouched when any of these hold:
/// - the newest event's `old_snapshot` belongs to a different buffer than
///   `latest_snapshot` (compared by `remote_id`);
/// - fewer than two trailing events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
///
/// The merged event keeps the oldest event's paths and `in_open_source_repo`. It is
/// marked `predicted` only if every merged event was a predicted `BufferChange`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the newest event refers to the same buffer as the latest snapshot.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Walk from newest to oldest, counting the consecutive trailing events that can
    // merge with their newer neighbour. The newest event always counts.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // `mergeable_count >= 2` here, so the range is non-empty and `peek` cannot fail.
    // `peek` does not advance the iterator, so the `all` below still visits the
    // oldest event.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        // The bindings below borrow through `peek()`. They are all consumed before
        // `events_to_merge.all(...)` needs the iterator mutably, so keep `predicted`
        // after the other fields.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the whole mergeable run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2345
2346pub(crate) fn filter_redundant_excerpts(
2347 mut related_files: Vec<RelatedFile>,
2348 cursor_path: &Path,
2349 cursor_row_range: Range<u32>,
2350) -> Vec<RelatedFile> {
2351 for file in &mut related_files {
2352 if file.path.as_ref() == cursor_path {
2353 file.excerpts.retain(|excerpt| {
2354 excerpt.row_range.start < cursor_row_range.start
2355 || excerpt.row_range.end > cursor_row_range.end
2356 });
2357 }
2358 }
2359 related_files.retain(|file| !file.excerpts.is_empty());
2360 related_files
2361}
2362
/// Error returned when the edit-prediction server requires a newer Zed than the
/// one making the request.
///
/// Produced when a response carries a minimum-required-version header whose
/// version is greater than the running app's version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum Zed version reported by the server.
    minimum_version: Version,
}
2370
/// The user's answer to whether edit-prediction data may be collected.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No recognized answer is stored: the entry is missing, unreadable, or
    /// holds an unknown value.
    NotAnswered,
    /// The user opted in (persisted as `"true"`).
    Enabled,
    /// The user opted out (persisted as `"false"`).
    Disabled,
}
2377
2378impl DataCollectionChoice {
2379 pub fn is_enabled(self, cx: &App) -> bool {
2380 if cx.is_staff() {
2381 return true;
2382 }
2383 match self {
2384 Self::Enabled => true,
2385 Self::NotAnswered | Self::Disabled => false,
2386 }
2387 }
2388
2389 #[must_use]
2390 pub fn toggle(&self) -> DataCollectionChoice {
2391 match self {
2392 Self::Enabled => Self::Disabled,
2393 Self::Disabled => Self::Enabled,
2394 Self::NotAnswered => Self::Enabled,
2395 }
2396 }
2397}
2398
2399impl From<bool> for DataCollectionChoice {
2400 fn from(value: bool) -> Self {
2401 match value {
2402 true => DataCollectionChoice::Enabled,
2403 false => DataCollectionChoice::Disabled,
2404 }
2405 }
2406}
2407
2408struct ZedPredictUpsell;
2409
2410impl Dismissable for ZedPredictUpsell {
2411 const KEY: &'static str = "dismissed-edit-predict-upsell";
2412
2413 fn dismissed() -> bool {
2414 // To make this backwards compatible with older versions of Zed, we
2415 // check if the user has seen the previous Edit Prediction Onboarding
2416 // before, by checking the data collection choice which was written to
2417 // the database once the user clicked on "Accept and Enable"
2418 if KEY_VALUE_STORE
2419 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2420 .log_err()
2421 .is_some_and(|s| s.is_some())
2422 {
2423 return true;
2424 }
2425
2426 KEY_VALUE_STORE
2427 .read_kvp(Self::KEY)
2428 .log_err()
2429 .is_some_and(|s| s.is_some())
2430 }
2431}
2432
2433pub fn should_show_upsell_modal() -> bool {
2434 !ZedPredictUpsell::dismissed()
2435}
2436
2437pub fn init(cx: &mut App) {
2438 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2439 workspace.register_action(
2440 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2441 ZedPredictModal::toggle(
2442 workspace,
2443 workspace.user_store().clone(),
2444 workspace.client().clone(),
2445 window,
2446 cx,
2447 )
2448 },
2449 );
2450
2451 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2452 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2453 settings
2454 .project
2455 .all_languages
2456 .edit_predictions
2457 .get_or_insert_default()
2458 .provider = Some(EditPredictionProvider::None)
2459 });
2460 });
2461 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2462 EditPredictionStore::try_global(cx).and_then(|store| {
2463 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2464 })
2465 }
2466
2467 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2468 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2469 copilot_ui::initiate_sign_in(copilot, window, cx);
2470 }
2471 });
2472 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2473 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2474 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2475 }
2476 });
2477 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2478 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2479 copilot_ui::initiate_sign_out(copilot, window, cx);
2480 }
2481 });
2482 })
2483 .detach();
2484}