1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use std::env;
40use text::Edit;
41use workspace::Workspace;
42use zeta_prompt::{ZetaFormat, ZetaPromptInput};
43
44use std::mem;
45use std::ops::Range;
46use std::path::Path;
47use std::rc::Rc;
48use std::str::FromStr as _;
49use std::sync::Arc;
50use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59pub mod ollama;
60mod onboarding_modal;
61pub mod open_ai_response;
62mod prediction;
63pub mod sweep_ai;
64
65pub mod udiff;
66
67mod capture_example;
68mod zed_edit_prediction_delegate;
69pub mod zeta1;
70pub mod zeta2;
71
72#[cfg(test)]
73mod edit_prediction_tests;
74
75use crate::capture_example::should_send_testing_zeta2_request;
76use crate::license_detection::LicenseDetectionWatcher;
77use crate::mercury::Mercury;
78use crate::ollama::Ollama;
79use crate::onboarding_modal::ZedPredictModal;
80pub use crate::prediction::EditPrediction;
81pub use crate::prediction::EditPredictionId;
82use crate::prediction::EditPredictionResult;
83pub use crate::sweep_ai::SweepAi;
84pub use capture_example::capture_example;
85pub use language_model::ApiKeyState;
86pub use telemetry_events::EditPredictionRating;
87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
88
// gpui actions declared in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Maximum distance, in lines, between two edit ranges for them to be treated as one change
/// group (used by `StoredEvent::can_merge`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window for grouping consecutive edits.
/// NOTE(review): not referenced in the visible part of this file — confirm its exact use.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
/// Name under which the data-collection choice is persisted.
/// NOTE(review): presumably a key in `KEY_VALUE_STORE`; verify against
/// `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval for reject requests.
/// NOTE(review): presumably consumed by `handle_rejected_predictions` — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
105
/// Feature flag gating Zeta2-only behavior, e.g. refreshing predictions when diagnostics
/// update (see `EditPredictionStore::handle_project_event`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Always `true`: the flag is on for staff regardless of other targeting.
    fn enabled_for_staff() -> bool {
        true
    }
}

/// App-global wrapper around the singleton `EditPredictionStore` entity, installed by
/// `EditPredictionStore::global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
120
/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
///
/// Populated from the `ZED_ZETA_FORMAT` / `ZED_ZETA_MODEL` environment variables by
/// `EditPredictionStore::zeta2_raw_config_from_env`, or explicitly via
/// `EditPredictionStore::set_zeta2_raw_config`.
///
/// NOTE(review): the earlier doc said "the version is also used as the Baseten environment name
/// (lowercased)", but this struct has no `version` field. It presumably refers to `format` —
/// confirm against the request-building code.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (taken from `ZED_ZETA_MODEL` when loaded from the environment).
    pub model_id: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` when loaded from the environment).
    pub format: ZetaFormat,
}
129
/// Central store for edit predictions. It holds per-project state (edit history, registered
/// buffers, pending/current predictions), the selected prediction backend, and provider clients.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): presumably set when the server reports that a newer client is required
    // (cf. `MINIMUM_REQUIRED_VERSION_HEADER_NAME`) — confirm.
    update_required: bool,
    /// Backend used for predictions; `new` starts with `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    /// When `Some`, Zeta2 uses the raw endpoint (see `Zeta2RawConfig`).
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    /// Sending half of the channel drained by the background rejection task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}

/// Backend that produces edit predictions. The derived `Default` is `Zeta1`, although
/// `EditPredictionStore::new` selects `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
    Ollama,
}

/// Inputs gathered for a single prediction request.
/// NOTE(review): presumably handed to the selected backend; the construction site is outside
/// this view, so field notes below are hedged where not grounded.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // Presumably the cursor position the prediction was requested at — confirm.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    /// Optional sink for `DebugEvent`s.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
172
/// Events delivered to an optional debug listener (see `EditPredictionStore::debug_info`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when a project's `RelatedExcerptStore` starts a refresh.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Left empty when emitted from `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Emitted when a project's `RelatedExcerptStore` finishes a refresh.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    /// Label/value pairs for display (e.g. "Cache Hits", "Max LSP Time").
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Debug event for the start of a prediction request.
/// NOTE(review): emission site is outside this view.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub prompt: Option<String>,
}

/// Debug event for the end of a prediction request.
/// NOTE(review): emission site is outside this view.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

/// Maximum number of user actions kept per project; `ProjectState::record_user_action` evicts
/// the oldest entry once this many are stored.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    // NOTE(review): row/offset base (0- vs 1-based) not established here — confirm.
    pub line_number: u32,
    pub offset: usize,
    /// Presumably milliseconds since the UNIX epoch, per the field name.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action; edit kinds are derived from deleted/inserted byte counts in
/// `report_changes_for_buffer`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}

/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot from before the change.
    pub old_snapshot: TextBufferSnapshot,
    /// Edited range, anchored in the post-change snapshot (see `LastEvent::finalize`).
    pub edit_range: Range<Anchor>,
}
236
237impl StoredEvent {
238 fn can_merge(
239 &self,
240 next_old_event: &&&StoredEvent,
241 new_snapshot: &TextBufferSnapshot,
242 last_edit_range: &Range<Anchor>,
243 ) -> bool {
244 // Events must be for the same buffer
245 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
246 return false;
247 }
248 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
249 return false;
250 }
251
252 let a_is_predicted = matches!(
253 self.event.as_ref(),
254 zeta_prompt::Event::BufferChange {
255 predicted: true,
256 ..
257 }
258 );
259 let b_is_predicted = matches!(
260 next_old_event.event.as_ref(),
261 zeta_prompt::Event::BufferChange {
262 predicted: true,
263 ..
264 }
265 );
266
267 // If events come from the same source (both predicted or both manual) then
268 // we would have coalesced them already.
269 if a_is_predicted == b_is_predicted {
270 return false;
271 }
272
273 let left_range = self.edit_range.to_point(new_snapshot);
274 let right_range = next_old_event.edit_range.to_point(new_snapshot);
275 let latest_range = last_edit_range.to_point(&new_snapshot);
276
277 // Events near to the latest edit are not merged if their sources differ.
278 if lines_between_ranges(&left_range, &latest_range)
279 .min(lines_between_ranges(&right_range, &latest_range))
280 <= CHANGE_GROUPING_LINE_SPAN
281 {
282 return false;
283 }
284
285 // Events that are distant from each other are not merged.
286 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
287 return false;
288 }
289
290 true
291 }
292}
293
294fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
295 if left.start > right.end {
296 return left.start.row - right.end.row;
297 }
298 if right.start > left.end {
299 return right.start.row - left.end.row;
300 }
301 0
302}
303
/// Edit-prediction bookkeeping for a single project.
struct ProjectState {
    /// Finalized edit events.
    events: VecDeque<StoredEvent>,
    /// In-progress event; `ProjectState::events` finalizes it on demand.
    last_event: Option<LastEvent>,
    /// Recently activated paths, most recent at the front (see `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    /// Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Optional debug listener installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    /// One license watcher per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    /// Bounded to `USER_ACTION_HISTORY_SIZE` entries by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    /// Project event subscription and project-release observer.
    _subscriptions: [gpui::Subscription; 2],
    /// Started lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
321
322impl ProjectState {
323 fn record_user_action(&mut self, action: UserActionRecord) {
324 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
325 self.user_actions.pop_front();
326 }
327 self.user_actions.push_back(action);
328 }
329
330 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
331 self.events
332 .iter()
333 .cloned()
334 .chain(self.last_event.as_ref().iter().flat_map(|event| {
335 let (one, two) = event.split_by_pause();
336 let one = one.finalize(&self.license_detection_watchers, cx);
337 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
338 one.into_iter().chain(two)
339 }))
340 .collect()
341 }
342
343 fn cancel_pending_prediction(
344 &mut self,
345 pending_prediction: PendingPrediction,
346 cx: &mut Context<EditPredictionStore>,
347 ) {
348 self.cancelled_predictions.insert(pending_prediction.id);
349
350 if pending_prediction.drop_on_cancel {
351 drop(pending_prediction.task);
352 } else {
353 cx.spawn(async move |this, cx| {
354 let Some(prediction_id) = pending_prediction.task.await else {
355 return;
356 };
357
358 this.update(cx, |this, cx| {
359 this.reject_prediction(
360 prediction_id,
361 EditPredictionRejectReason::Canceled,
362 false,
363 cx,
364 );
365 })
366 .ok();
367 })
368 .detach()
369 }
370 }
371
372 fn active_buffer(
373 &self,
374 project: &Entity<Project>,
375 cx: &App,
376 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
377 let project = project.read(cx);
378 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
379 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
380 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
381 Some((active_buffer, registered_buffer.last_position))
382 }
383}
384
/// The prediction currently offered for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request (a buffer or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction is rendered to the user — confirm.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
392
393impl CurrentEditPrediction {
394 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
395 let Some(new_edits) = self
396 .prediction
397 .interpolate(&self.prediction.buffer.read(cx))
398 else {
399 return false;
400 };
401
402 if self.prediction.buffer != old_prediction.prediction.buffer {
403 return true;
404 }
405
406 let Some(old_edits) = old_prediction
407 .prediction
408 .interpolate(&old_prediction.prediction.buffer.read(cx))
409 else {
410 return true;
411 };
412
413 let requested_by_buffer_id = self.requested_by.buffer_id();
414
415 // This reduces the occurrence of UI thrash from replacing edits
416 //
417 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
418 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
419 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
420 && old_edits.len() == 1
421 && new_edits.len() == 1
422 {
423 let (old_range, old_text) = &old_edits[0];
424 let (new_range, new_text) = &new_edits[0];
425 new_range == old_range && new_text.starts_with(old_text.as_ref())
426 } else {
427 true
428 }
429 }
430}
431
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A project diagnostics update (not tied to a specific buffer).
    DiagnosticsUpdate,
    /// Activity in the buffer with this entity id.
    Buffer(EntityId),
}
437
438impl PredictionRequestedBy {
439 pub fn buffer_id(&self) -> Option<EntityId> {
440 match self {
441 PredictionRequestedBy::DiagnosticsUpdate => None,
442 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
443 }
444 }
445}
446
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Per-project id, used to mark cancellation in `ProjectState::cancelled_predictions`.
    id: usize,
    /// Resolves to the prediction id produced by the request, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}

/// A prediction from the perspective of a buffer.
/// NOTE(review): presumably `Local` when the edits apply in this buffer and `Jump` when they
/// point elsewhere — confirm against the construction site.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
462
/// Test-only convenience: both variants expose the wrapped `EditPrediction`.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
474
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// File as of the last reported change (swapped in `report_changes_for_buffer`).
    file: Option<Arc<dyn File>>,
    /// Snapshot as of the last reported change.
    snapshot: TextBufferSnapshot,
    // NOTE(review): presumably the last known cursor position in this buffer — confirm.
    last_position: Option<Anchor>,
    /// Edit subscription and buffer-release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}

/// An edit event that has not been finalized into a `StoredEvent` yet.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    /// Not read by `finalize`, which recomputes the edited range from the diff.
    edit_range: Option<Range<Anchor>>,
    /// Whether this change came from an accepted prediction rather than manual editing.
    predicted: bool,
    /// When set, `split_by_pause` splits the event at this snapshot.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
493
494impl LastEvent {
495 pub fn finalize(
496 &self,
497 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
498 cx: &App,
499 ) -> Option<StoredEvent> {
500 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
501 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
502
503 let in_open_source_repo =
504 [self.new_file.as_ref(), self.old_file.as_ref()]
505 .iter()
506 .all(|file| {
507 file.is_some_and(|file| {
508 license_detection_watchers
509 .get(&file.worktree_id(cx))
510 .is_some_and(|watcher| watcher.is_project_open_source())
511 })
512 });
513
514 let (diff, edit_range) =
515 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
516
517 if path == old_path && diff.is_empty() {
518 None
519 } else {
520 Some(StoredEvent {
521 event: Arc::new(zeta_prompt::Event::BufferChange {
522 old_path,
523 path,
524 diff,
525 in_open_source_repo,
526 predicted: self.predicted,
527 }),
528 edit_range: self.new_snapshot.anchor_before(edit_range.start)
529 ..self.new_snapshot.anchor_before(edit_range.end),
530 old_snapshot: self.old_snapshot.clone(),
531 })
532 }
533 }
534
535 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
536 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
537 return (self.clone(), None);
538 };
539
540 let before = LastEvent {
541 old_snapshot: self.old_snapshot.clone(),
542 new_snapshot: boundary_snapshot.clone(),
543 old_file: self.old_file.clone(),
544 new_file: self.new_file.clone(),
545 edit_range: None,
546 predicted: self.predicted,
547 snapshot_after_last_editing_pause: None,
548 last_edit_time: self.last_edit_time,
549 };
550
551 let after = LastEvent {
552 old_snapshot: boundary_snapshot.clone(),
553 new_snapshot: self.new_snapshot.clone(),
554 old_file: self.old_file.clone(),
555 new_file: self.new_file.clone(),
556 edit_range: None,
557 predicted: self.predicted,
558 snapshot_after_last_editing_pause: None,
559 last_edit_time: self.last_edit_time,
560 };
561
562 (before, Some(after))
563 }
564}
565
/// Computes a unified diff between two snapshots of the same buffer, limited to the rows
/// spanned by the edits since `old_snapshot` plus surrounding context rows.
///
/// Returns `None` when there are no edits between the two versions. Otherwise returns the diff
/// text and the edited point range in `new_snapshot`: from the start of the first edit to the
/// end of the last edit.
pub(crate) fn compute_diff_between_snapshots(
    old_snapshot: &TextBufferSnapshot,
    new_snapshot: &TextBufferSnapshot,
) -> Option<(String, Range<Point>)> {
    let edits: Vec<Edit<usize>> = new_snapshot
        .edits_since::<usize>(&old_snapshot.version)
        .collect();

    // No edits => nothing to diff.
    let (first_edit, last_edit) = edits.first().zip(edits.last())?;

    let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
    let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
    let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
    let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);

    const CONTEXT_LINES: u32 = 3;

    // Expand the edited rows by CONTEXT_LINES, clamped to the buffer bounds.
    // NOTE(review): the end row adds `1 + CONTEXT_LINES` and the end offset below adds one more
    // row, so (unclamped) the trailing context is one row longer than the leading context —
    // confirm this asymmetry is intended.
    let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
    let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
    let old_context_end_row =
        (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
    let new_context_end_row =
        (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);

    // Byte offsets of whole-line regions: start of the first context row up to the start of the
    // row after the last context row (or end of buffer).
    let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
    let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
    let old_end_line_offset = old_snapshot
        .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
    let new_end_line_offset = new_snapshot
        .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
    let old_edit_range = old_start_line_offset..old_end_line_offset;
    let new_edit_range = new_start_line_offset..new_end_line_offset;

    let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
    let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();

    // Presumably the row offsets make hunk headers report buffer line numbers — confirm.
    let diff = language::unified_diff_with_offsets(
        &old_region_text,
        &new_region_text,
        old_context_start_row,
        new_context_start_row,
    );

    Some((diff, new_start_point..new_end_point))
}
611
612fn buffer_path_with_id_fallback(
613 file: Option<&Arc<dyn File>>,
614 snapshot: &TextBufferSnapshot,
615 cx: &App,
616) -> Arc<Path> {
617 if let Some(file) = file {
618 file.full_path(cx).into()
619 } else {
620 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
621 }
622}
623
624impl EditPredictionStore {
625 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
626 cx.try_global::<EditPredictionStoreGlobal>()
627 .map(|global| global.0.clone())
628 }
629
630 pub fn global(
631 client: &Arc<Client>,
632 user_store: &Entity<UserStore>,
633 cx: &mut App,
634 ) -> Entity<Self> {
635 cx.try_global::<EditPredictionStoreGlobal>()
636 .map(|global| global.0.clone())
637 .unwrap_or_else(|| {
638 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
639 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
640 ep_store
641 })
642 }
643
644 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
645 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
646 let data_collection_choice = Self::load_data_collection_choice();
647
648 let llm_token = LlmApiToken::default();
649
650 let (reject_tx, reject_rx) = mpsc::unbounded();
651 cx.background_spawn({
652 let client = client.clone();
653 let llm_token = llm_token.clone();
654 let app_version = AppVersion::global(cx);
655 let background_executor = cx.background_executor().clone();
656 async move {
657 Self::handle_rejected_predictions(
658 reject_rx,
659 client,
660 llm_token,
661 app_version,
662 background_executor,
663 )
664 .await
665 }
666 })
667 .detach();
668
669 let this = Self {
670 projects: HashMap::default(),
671 client,
672 user_store,
673 llm_token,
674 _llm_token_subscription: cx.subscribe(
675 &refresh_llm_token_listener,
676 |this, _listener, _event, cx| {
677 let client = this.client.clone();
678 let llm_token = this.llm_token.clone();
679 cx.spawn(async move |_this, _cx| {
680 llm_token.refresh(&client).await?;
681 anyhow::Ok(())
682 })
683 .detach_and_log_err(cx);
684 },
685 ),
686 update_required: false,
687 edit_prediction_model: EditPredictionModel::Zeta2,
688 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
689 sweep_ai: SweepAi::new(cx),
690 mercury: Mercury::new(cx),
691 ollama: Ollama::new(),
692
693 data_collection_choice,
694 reject_predictions_tx: reject_tx,
695 rated_predictions: Default::default(),
696 shown_predictions: Default::default(),
697 };
698
699 this
700 }
701
702 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
703 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
704 let format = ZetaFormat::parse(&version_str).ok()?;
705 let model_id = env::var("ZED_ZETA_MODEL").ok();
706 Some(Zeta2RawConfig { model_id, format })
707 }
708
709 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
710 self.edit_prediction_model = model;
711 }
712
713 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
714 self.zeta2_raw_config = Some(config);
715 }
716
717 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
718 self.zeta2_raw_config.as_ref()
719 }
720
721 pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
722 use ui::IconName;
723 match self.edit_prediction_model {
724 EditPredictionModel::Sweep => {
725 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
726 .with_disabled(IconName::SweepAiDisabled)
727 .with_up(IconName::SweepAiUp)
728 .with_down(IconName::SweepAiDown)
729 .with_error(IconName::SweepAiError)
730 }
731 EditPredictionModel::Mercury => {
732 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
733 }
734 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
735 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
736 .with_disabled(IconName::ZedPredictDisabled)
737 .with_up(IconName::ZedPredictUp)
738 .with_down(IconName::ZedPredictDown)
739 .with_error(IconName::ZedPredictError)
740 }
741 EditPredictionModel::Ollama => {
742 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
743 }
744 }
745 }
746
747 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
748 self.sweep_ai.api_token.read(cx).has_key()
749 }
750
751 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
752 self.mercury.api_token.read(cx).has_key()
753 }
754
755 pub fn clear_history(&mut self) {
756 for project_state in self.projects.values_mut() {
757 project_state.events.clear();
758 project_state.last_event.take();
759 }
760 }
761
762 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
763 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
764 project_state.events.clear();
765 project_state.last_event.take();
766 }
767 }
768
769 pub fn edit_history_for_project(
770 &self,
771 project: &Entity<Project>,
772 cx: &App,
773 ) -> Vec<StoredEvent> {
774 self.projects
775 .get(&project.entity_id())
776 .map(|project_state| project_state.events(cx))
777 .unwrap_or_default()
778 }
779
780 pub fn context_for_project<'a>(
781 &'a self,
782 project: &Entity<Project>,
783 cx: &'a mut App,
784 ) -> Vec<RelatedFile> {
785 self.projects
786 .get(&project.entity_id())
787 .map(|project| {
788 project
789 .context
790 .update(cx, |context, cx| context.related_files(cx))
791 })
792 .unwrap_or_default()
793 }
794
795 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
796 self.projects
797 .get(&project.entity_id())
798 .and_then(|project| project.copilot.clone())
799 }
800
801 pub fn start_copilot_for_project(
802 &mut self,
803 project: &Entity<Project>,
804 cx: &mut Context<Self>,
805 ) -> Option<Entity<Copilot>> {
806 if DisableAiSettings::get(None, cx).disable_ai {
807 return None;
808 }
809 let state = self.get_or_init_project(project, cx);
810
811 if state.copilot.is_some() {
812 return state.copilot.clone();
813 }
814 let _project = project.clone();
815 let project = project.read(cx);
816
817 let node = project.node_runtime().cloned();
818 if let Some(node) = node {
819 let next_id = project.languages().next_language_server_id();
820 let fs = project.fs().clone();
821
822 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
823 state.copilot = Some(copilot.clone());
824 Some(copilot)
825 } else {
826 None
827 }
828 }
829
830 pub fn context_for_project_with_buffers<'a>(
831 &'a self,
832 project: &Entity<Project>,
833 cx: &'a mut App,
834 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
835 self.projects
836 .get(&project.entity_id())
837 .map(|project| {
838 project
839 .context
840 .update(cx, |context, cx| context.related_files_with_buffers(cx))
841 })
842 .unwrap_or_default()
843 }
844
845 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
846 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
847 self.user_store.read(cx).edit_prediction_usage()
848 } else {
849 None
850 }
851 }
852
853 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
854 self.get_or_init_project(project, cx);
855 }
856
857 pub fn register_buffer(
858 &mut self,
859 buffer: &Entity<Buffer>,
860 project: &Entity<Project>,
861 cx: &mut Context<Self>,
862 ) {
863 let project_state = self.get_or_init_project(project, cx);
864 Self::register_buffer_impl(project_state, buffer, project, cx);
865 }
866
    /// Returns the `ProjectState` for `project`, creating it on first use.
    ///
    /// On creation this builds a `RelatedExcerptStore` whose events are forwarded to
    /// `handle_excerpt_store_event`, subscribes to project events (`handle_project_event`),
    /// and removes the state again when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached rather than kept in `_subscriptions`. `handle_excerpt_store_event`
                    // tolerates the project entry having been removed.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project entity goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
906
907 pub fn remove_project(&mut self, project: &Entity<Project>) {
908 self.projects.remove(&project.entity_id());
909 }
910
911 fn handle_excerpt_store_event(
912 &mut self,
913 project_entity_id: EntityId,
914 event: &RelatedExcerptStoreEvent,
915 ) {
916 if let Some(project_state) = self.projects.get(&project_entity_id) {
917 if let Some(debug_tx) = project_state.debug_tx.clone() {
918 match event {
919 RelatedExcerptStoreEvent::StartedRefresh => {
920 debug_tx
921 .unbounded_send(DebugEvent::ContextRetrievalStarted(
922 ContextRetrievalStartedDebugEvent {
923 project_entity_id: project_entity_id,
924 timestamp: Instant::now(),
925 search_prompt: String::new(),
926 },
927 ))
928 .ok();
929 }
930 RelatedExcerptStoreEvent::FinishedRefresh {
931 cache_hit_count,
932 cache_miss_count,
933 mean_definition_latency,
934 max_definition_latency,
935 } => {
936 debug_tx
937 .unbounded_send(DebugEvent::ContextRetrievalFinished(
938 ContextRetrievalFinishedDebugEvent {
939 project_entity_id: project_entity_id,
940 timestamp: Instant::now(),
941 metadata: vec![
942 (
943 "Cache Hits",
944 format!(
945 "{}/{}",
946 cache_hit_count,
947 cache_hit_count + cache_miss_count
948 )
949 .into(),
950 ),
951 (
952 "Max LSP Time",
953 format!("{} ms", max_definition_latency.as_millis())
954 .into(),
955 ),
956 (
957 "Mean LSP Time",
958 format!("{} ms", mean_definition_latency.as_millis())
959 .into(),
960 ),
961 ],
962 },
963 ))
964 .ok();
965 }
966 }
967 }
968 }
969 }
970
971 pub fn debug_info(
972 &mut self,
973 project: &Entity<Project>,
974 cx: &mut Context<Self>,
975 ) -> mpsc::UnboundedReceiver<DebugEvent> {
976 let project_state = self.get_or_init_project(project, cx);
977 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
978 project_state.debug_tx = Some(debug_watch_tx);
979 debug_watch_rx
980 }
981
982 fn handle_project_event(
983 &mut self,
984 project: Entity<Project>,
985 event: &project::Event,
986 cx: &mut Context<Self>,
987 ) {
988 // TODO [zeta2] init with recent paths
989 match event {
990 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
991 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
992 return;
993 };
994 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
995 if let Some(path) = path {
996 if let Some(ix) = project_state
997 .recent_paths
998 .iter()
999 .position(|probe| probe == &path)
1000 {
1001 project_state.recent_paths.remove(ix);
1002 }
1003 project_state.recent_paths.push_front(path);
1004 }
1005 }
1006 project::Event::DiagnosticsUpdated { .. } => {
1007 if cx.has_flag::<Zeta2FeatureFlag>() {
1008 self.refresh_prediction_from_diagnostics(project, cx);
1009 }
1010 }
1011 _ => (),
1012 }
1013 }
1014
    /// Ensures `buffer` is tracked in `project_state` and returns its `RegisteredBuffer`.
    ///
    /// When the buffer has a file in one of the project's worktrees, a
    /// `LicenseDetectionWatcher` is created for that worktree on first sight and removed when
    /// the worktree is released. On first registration, the buffer's `Edited` events are
    /// forwarded to `report_changes_for_buffer` as non-predicted changes, and the entry is
    /// removed when the buffer is released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree is gone.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Weak project handle: the subscription must not keep the project alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1080
    /// Records the edits made to `buffer` since the snapshot last stored for it.
    ///
    /// Does nothing when the buffer version is unchanged or when no edits are found.
    /// Otherwise it:
    /// - updates the registered buffer's stored file and snapshot;
    /// - appends a `UserActionRecord` (char/selection insert or delete) at the end of the last
    ///   edit;
    /// - either coalesces the edit into `project_state.last_event` or finalizes the previous
    ///   `last_event` into the bounded `events` history and starts a new `LastEvent`.
    ///
    /// An edit is coalesced only when it continues from the last event's latest snapshot of
    /// the same buffer, has the same `is_predicted` source, and lies within
    /// `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit range.
    ///
    /// `is_predicted` is `true` when the changes come from accepting a prediction.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Aggregate every edit since the previous snapshot: counts, sizes, the overall anchor
        // range (first edit start .. last edit end) and the offset after the last edit.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Classify the change: a length of exactly one per edit (e.g. typing with one or more
        // cursors) is a "char" action, anything else a "selection" action.
        // NOTE(review): lengths are `usize` offset spans (presumably bytes), so a multi-byte
        // character would be classified as a selection — confirm this is intended.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // This edit continues exactly from where the last event left off.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            // Predicted and user-typed edits are never merged into one event.
            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the snapshot
                // from before this edit as the state at the last editing pause.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Not coalesced: move the previous event into the history. The history is capped so
        // it holds at most EVENT_COUNT_MAX - 1 finalized events (presumably leaving room for
        // `last_event`).
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        // NOTE(review): helper defined elsewhere; presumably merges trailing history events
        // related to this edit — confirm.
        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1208
1209 fn prediction_at(
1210 &mut self,
1211 buffer: &Entity<Buffer>,
1212 position: Option<language::Anchor>,
1213 project: &Entity<Project>,
1214 cx: &App,
1215 ) -> Option<BufferEditPrediction<'_>> {
1216 let project_state = self.projects.get_mut(&project.entity_id())?;
1217 if let Some(position) = position
1218 && let Some(buffer) = project_state
1219 .registered_buffers
1220 .get_mut(&buffer.entity_id())
1221 {
1222 buffer.last_position = Some(position);
1223 }
1224
1225 let CurrentEditPrediction {
1226 requested_by,
1227 prediction,
1228 ..
1229 } = project_state.current_prediction.as_ref()?;
1230
1231 if prediction.targets_buffer(buffer.read(cx)) {
1232 Some(BufferEditPrediction::Local { prediction })
1233 } else {
1234 let show_jump = match requested_by {
1235 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1236 requested_by_buffer_id == &buffer.entity_id()
1237 }
1238 PredictionRequestedBy::DiagnosticsUpdate => true,
1239 };
1240
1241 if show_jump {
1242 Some(BufferEditPrediction::Jump { prediction })
1243 } else {
1244 None
1245 }
1246 }
1247 }
1248
1249 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1250 let Some(current_prediction) = self
1251 .projects
1252 .get_mut(&project.entity_id())
1253 .and_then(|project_state| project_state.current_prediction.take())
1254 else {
1255 return;
1256 };
1257
1258 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1259
1260 // can't hold &mut project_state ref across report_changes_for_buffer_call
1261 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1262 return;
1263 };
1264
1265 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1266 project_state.cancel_pending_prediction(pending_prediction, cx);
1267 }
1268
1269 match self.edit_prediction_model {
1270 EditPredictionModel::Sweep => {
1271 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1272 }
1273 EditPredictionModel::Mercury => {
1274 mercury::edit_prediction_accepted(
1275 current_prediction.prediction.id,
1276 self.client.http_client(),
1277 cx,
1278 );
1279 }
1280 EditPredictionModel::Ollama => {}
1281 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1282 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1283 }
1284 }
1285 }
1286
1287 async fn handle_rejected_predictions(
1288 rx: UnboundedReceiver<EditPredictionRejection>,
1289 client: Arc<Client>,
1290 llm_token: LlmApiToken,
1291 app_version: Version,
1292 background_executor: BackgroundExecutor,
1293 ) {
1294 let mut rx = std::pin::pin!(rx.peekable());
1295 let mut batched = Vec::new();
1296
1297 while let Some(rejection) = rx.next().await {
1298 batched.push(rejection);
1299
1300 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1301 select_biased! {
1302 next = rx.as_mut().peek().fuse() => {
1303 if next.is_some() {
1304 continue;
1305 }
1306 }
1307 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1308 }
1309 }
1310
1311 let url = client
1312 .http_client()
1313 .build_zed_llm_url("/predict_edits/reject", &[])
1314 .unwrap();
1315
1316 let flush_count = batched
1317 .len()
1318 // in case items have accumulated after failure
1319 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1320 let start = batched.len() - flush_count;
1321
1322 let body = RejectEditPredictionsBodyRef {
1323 rejections: &batched[start..],
1324 };
1325
1326 let result = Self::send_api_request::<()>(
1327 |builder| {
1328 let req = builder
1329 .uri(url.as_ref())
1330 .body(serde_json::to_string(&body)?.into());
1331 anyhow::Ok(req?)
1332 },
1333 client.clone(),
1334 llm_token.clone(),
1335 app_version.clone(),
1336 true,
1337 )
1338 .await;
1339
1340 if result.log_err().is_some() {
1341 batched.drain(start..);
1342 }
1343 }
1344 }
1345
1346 fn reject_current_prediction(
1347 &mut self,
1348 reason: EditPredictionRejectReason,
1349 project: &Entity<Project>,
1350 cx: &App,
1351 ) {
1352 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1353 project_state.pending_predictions.clear();
1354 if let Some(prediction) = project_state.current_prediction.take() {
1355 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1356 }
1357 };
1358 }
1359
1360 fn did_show_current_prediction(
1361 &mut self,
1362 project: &Entity<Project>,
1363 display_type: edit_prediction_types::SuggestionDisplayType,
1364 cx: &mut Context<Self>,
1365 ) {
1366 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1367 return;
1368 };
1369
1370 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1371 return;
1372 };
1373
1374 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1375 let previous_shown_with = current_prediction.shown_with;
1376
1377 if previous_shown_with.is_none() || !is_jump {
1378 current_prediction.shown_with = Some(display_type);
1379 }
1380
1381 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1382
1383 if is_first_non_jump_show {
1384 current_prediction.was_shown = true;
1385 }
1386
1387 let display_type_changed = previous_shown_with != Some(display_type);
1388
1389 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1390 sweep_ai::edit_prediction_shown(
1391 &self.sweep_ai,
1392 self.client.clone(),
1393 ¤t_prediction.prediction,
1394 display_type,
1395 cx,
1396 );
1397 }
1398
1399 if is_first_non_jump_show {
1400 self.shown_predictions
1401 .push_front(current_prediction.prediction.clone());
1402 if self.shown_predictions.len() > 50 {
1403 let completion = self.shown_predictions.pop_back().unwrap();
1404 self.rated_predictions.remove(&completion.id);
1405 }
1406 }
1407 }
1408
1409 fn reject_prediction(
1410 &mut self,
1411 prediction_id: EditPredictionId,
1412 reason: EditPredictionRejectReason,
1413 was_shown: bool,
1414 cx: &App,
1415 ) {
1416 match self.edit_prediction_model {
1417 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1418 self.reject_predictions_tx
1419 .unbounded_send(EditPredictionRejection {
1420 request_id: prediction_id.to_string(),
1421 reason,
1422 was_shown,
1423 })
1424 .log_err();
1425 }
1426 EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1427 EditPredictionModel::Mercury => {
1428 mercury::edit_prediction_rejected(
1429 prediction_id,
1430 was_shown,
1431 reason,
1432 self.client.http_client(),
1433 cx,
1434 );
1435 }
1436 }
1437 }
1438
1439 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1440 self.projects
1441 .get(&project.entity_id())
1442 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1443 }
1444
1445 pub fn refresh_prediction_from_buffer(
1446 &mut self,
1447 project: Entity<Project>,
1448 buffer: Entity<Buffer>,
1449 position: language::Anchor,
1450 cx: &mut Context<Self>,
1451 ) {
1452 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1453 let Some(request_task) = this
1454 .update(cx, |this, cx| {
1455 this.request_prediction(
1456 &project,
1457 &buffer,
1458 position,
1459 PredictEditsRequestTrigger::Other,
1460 cx,
1461 )
1462 })
1463 .log_err()
1464 else {
1465 return Task::ready(anyhow::Ok(None));
1466 };
1467
1468 cx.spawn(async move |_cx| {
1469 request_task.await.map(|prediction_result| {
1470 prediction_result.map(|prediction_result| {
1471 (
1472 prediction_result,
1473 PredictionRequestedBy::Buffer(buffer.entity_id()),
1474 )
1475 })
1476 })
1477 })
1478 })
1479 }
1480
    /// Queues a throttled prediction request in response to a diagnostics update.
    ///
    /// Nothing is queued when the project already has a current prediction. Otherwise the
    /// queued task:
    /// 1. looks up the project's active buffer and cursor, and gives up if predictions are
    ///    disabled there;
    /// 2. finds a diagnostic location with `next_diagnostic_location`;
    /// 3. requests a prediction at that location with the `Diagnostics` trigger.
    ///
    /// If a current prediction appeared while the request was in flight, the result is
    /// turned into a `CurrentPreferred` rejection. Throttling is keyed by the project entity.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, the buffer start is used.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // An empty (default) search range is passed, so presumably no diagnostic of
                // the active buffer is excluded as "too close" — confirm against
                // `Range::overlaps`.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A prediction that became current in the meantime wins over this one.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1567
1568 fn predictions_enabled_at(
1569 snapshot: &BufferSnapshot,
1570 position: Option<language::Anchor>,
1571 cx: &App,
1572 ) -> bool {
1573 let file = snapshot.file();
1574 let all_settings = all_language_settings(file, cx);
1575 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1576 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1577 {
1578 return false;
1579 }
1580
1581 if let Some(last_position) = position {
1582 let settings = snapshot.settings_at(last_position, cx);
1583
1584 if !settings.edit_predictions_disabled_in.is_empty()
1585 && let Some(scope) = snapshot.language_scope_at(last_position)
1586 && let Some(scope_name) = scope.override_name()
1587 && settings
1588 .edit_predictions_disabled_in
1589 .iter()
1590 .any(|s| s == scope_name)
1591 {
1592 return false;
1593 }
1594 }
1595
1596 true
1597 }
1598
1599 #[cfg(not(test))]
1600 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1601 #[cfg(test)]
1602 pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1603
1604 fn queue_prediction_refresh(
1605 &mut self,
1606 project: Entity<Project>,
1607 throttle_entity: EntityId,
1608 cx: &mut Context<Self>,
1609 do_refresh: impl FnOnce(
1610 WeakEntity<Self>,
1611 &mut AsyncApp,
1612 )
1613 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1614 + 'static,
1615 ) {
1616 let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1617 let drop_on_cancel = is_ollama;
1618 let max_pending_predictions = if is_ollama { 1 } else { 2 };
1619 let project_state = self.get_or_init_project(&project, cx);
1620 let pending_prediction_id = project_state.next_pending_prediction_id;
1621 project_state.next_pending_prediction_id += 1;
1622 let last_request = project_state.last_prediction_refresh;
1623
1624 let task = cx.spawn(async move |this, cx| {
1625 if let Some((last_entity, last_timestamp)) = last_request
1626 && throttle_entity == last_entity
1627 && let Some(timeout) =
1628 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1629 {
1630 cx.background_executor().timer(timeout).await;
1631 }
1632
1633 // If this task was cancelled before the throttle timeout expired,
1634 // do not perform a request.
1635 let mut is_cancelled = true;
1636 this.update(cx, |this, cx| {
1637 let project_state = this.get_or_init_project(&project, cx);
1638 if !project_state
1639 .cancelled_predictions
1640 .remove(&pending_prediction_id)
1641 {
1642 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1643 is_cancelled = false;
1644 }
1645 })
1646 .ok();
1647 if is_cancelled {
1648 return None;
1649 }
1650
1651 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1652 let new_prediction_id = new_prediction_result
1653 .as_ref()
1654 .map(|(prediction, _)| prediction.id.clone());
1655
1656 // When a prediction completes, remove it from the pending list, and cancel
1657 // any pending predictions that were enqueued before it.
1658 this.update(cx, |this, cx| {
1659 let project_state = this.get_or_init_project(&project, cx);
1660
1661 let is_cancelled = project_state
1662 .cancelled_predictions
1663 .remove(&pending_prediction_id);
1664
1665 let new_current_prediction = if !is_cancelled
1666 && let Some((prediction_result, requested_by)) = new_prediction_result
1667 {
1668 match prediction_result.prediction {
1669 Ok(prediction) => {
1670 let new_prediction = CurrentEditPrediction {
1671 requested_by,
1672 prediction,
1673 was_shown: false,
1674 shown_with: None,
1675 };
1676
1677 if let Some(current_prediction) =
1678 project_state.current_prediction.as_ref()
1679 {
1680 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1681 {
1682 this.reject_current_prediction(
1683 EditPredictionRejectReason::Replaced,
1684 &project,
1685 cx,
1686 );
1687
1688 Some(new_prediction)
1689 } else {
1690 this.reject_prediction(
1691 new_prediction.prediction.id,
1692 EditPredictionRejectReason::CurrentPreferred,
1693 false,
1694 cx,
1695 );
1696 None
1697 }
1698 } else {
1699 Some(new_prediction)
1700 }
1701 }
1702 Err(reject_reason) => {
1703 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1704 None
1705 }
1706 }
1707 } else {
1708 None
1709 };
1710
1711 let project_state = this.get_or_init_project(&project, cx);
1712
1713 if let Some(new_prediction) = new_current_prediction {
1714 project_state.current_prediction = Some(new_prediction);
1715 }
1716
1717 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1718 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1719 if pending_prediction.id == pending_prediction_id {
1720 pending_predictions.remove(ix);
1721 for pending_prediction in pending_predictions.drain(0..ix) {
1722 project_state.cancel_pending_prediction(pending_prediction, cx)
1723 }
1724 break;
1725 }
1726 }
1727 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1728 cx.notify();
1729 })
1730 .ok();
1731
1732 new_prediction_id
1733 });
1734
1735 if project_state.pending_predictions.len() < max_pending_predictions {
1736 project_state.pending_predictions.push(PendingPrediction {
1737 id: pending_prediction_id,
1738 task,
1739 drop_on_cancel,
1740 });
1741 } else {
1742 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1743 project_state.pending_predictions.push(PendingPrediction {
1744 id: pending_prediction_id,
1745 task,
1746 drop_on_cancel,
1747 });
1748 project_state.cancel_pending_prediction(pending_prediction, cx);
1749 }
1750 }
1751
1752 pub fn request_prediction(
1753 &mut self,
1754 project: &Entity<Project>,
1755 active_buffer: &Entity<Buffer>,
1756 position: language::Anchor,
1757 trigger: PredictEditsRequestTrigger,
1758 cx: &mut Context<Self>,
1759 ) -> Task<Result<Option<EditPredictionResult>>> {
1760 self.request_prediction_internal(
1761 project.clone(),
1762 active_buffer.clone(),
1763 position,
1764 trigger,
1765 cx.has_flag::<Zeta2FeatureFlag>(),
1766 cx,
1767 )
1768 }
1769
    /// Builds an `EditPredictionModelInput` for `position` in `active_buffer` and dispatches it
    /// to the configured model.
    ///
    /// The input contains:
    /// - the project's recorded events, related files and recent paths;
    /// - the recorded user actions, plus a synthetic `CursorMovement` action when the cursor
    ///   is in the same buffer as the last action but at a different offset;
    /// - a diagnostic search range spanning `DIAGNOSTIC_LINES_RANGE` rows around the cursor.
    ///
    /// Jump fallback: the model may return no prediction while `allow_jump` is set and events
    /// are recorded. Then the nearest diagnostic outside that range (see
    /// `next_diagnostic_location`) is requested once more, with `allow_jump = false`.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Make sure the project state exists, then borrow it immutably.
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // Record a cursor movement if the cursor moved away from the last action in this
        // buffer (only added to this request's input, not stored).
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        let task = match &self.edit_prediction_model {
            EditPredictionModel::Zeta1 => {
                // Optionally mirror the request to Zeta2 with the `Testing` trigger; that task
                // is detached and its result is not used here.
                if should_send_testing_zeta2_request() {
                    let mut zeta2_inputs = inputs.clone();
                    zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
                    zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach();
                }
                zeta1::request_prediction_with_zeta1(self, inputs, cx)
            }
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                // Nothing predicted at the cursor: retry once at the nearest diagnostic
                // outside the search range, with jumps disabled to stop further recursion.
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1885
    /// Finds a location to request a "jump" prediction at.
    ///
    /// Lookup order:
    /// 1. Among the primary entries of the active buffer's diagnostic groups whose range does
    ///    not overlap `active_buffer_diagnostic_search_range`, take the one whose start row is
    ///    closest to `active_buffer_cursor_point`.
    /// 2. Otherwise, open the file (other than the active one) from the project's diagnostic
    ///    summaries that shares the most leading path components with the active buffer's
    ///    path. Return the start of its diagnostic with the smallest `severity` value
    ///    (presumably the most severe, if this is an LSP severity — confirm).
    ///
    /// Returns `Ok(None)` when neither step finds a location. Errors from opening the fallback
    /// buffer are propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another file with diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1964
1965 async fn send_raw_llm_request(
1966 request: RawCompletionRequest,
1967 client: Arc<Client>,
1968 custom_url: Option<Arc<Url>>,
1969 llm_token: LlmApiToken,
1970 app_version: Version,
1971 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1972 let url = if let Some(custom_url) = custom_url {
1973 custom_url.as_ref().clone()
1974 } else {
1975 client
1976 .http_client()
1977 .build_zed_llm_url("/predict_edits/raw", &[])?
1978 };
1979
1980 Self::send_api_request(
1981 |builder| {
1982 let req = builder
1983 .uri(url.as_ref())
1984 .body(serde_json::to_string(&request)?.into());
1985 Ok(req?)
1986 },
1987 client,
1988 llm_token,
1989 app_version,
1990 true,
1991 )
1992 .await
1993 }
1994
1995 pub(crate) async fn send_v3_request(
1996 input: ZetaPromptInput,
1997 client: Arc<Client>,
1998 llm_token: LlmApiToken,
1999 app_version: Version,
2000 trigger: PredictEditsRequestTrigger,
2001 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2002 let url = client
2003 .http_client()
2004 .build_zed_llm_url("/predict_edits/v3", &[])?;
2005
2006 let request = PredictEditsV3Request { input, trigger };
2007
2008 let json_bytes = serde_json::to_vec(&request)?;
2009 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2010
2011 Self::send_api_request(
2012 |builder| {
2013 let req = builder
2014 .uri(url.as_ref())
2015 .header("Content-Encoding", "zstd")
2016 .body(compressed.clone().into());
2017 Ok(req?)
2018 },
2019 client,
2020 llm_token,
2021 app_version,
2022 true,
2023 )
2024 .await
2025 }
2026
2027 fn handle_api_response<T>(
2028 this: &WeakEntity<Self>,
2029 response: Result<(T, Option<EditPredictionUsage>)>,
2030 cx: &mut gpui::AsyncApp,
2031 ) -> Result<T> {
2032 match response {
2033 Ok((data, usage)) => {
2034 if let Some(usage) = usage {
2035 this.update(cx, |this, cx| {
2036 this.user_store.update(cx, |user_store, cx| {
2037 user_store.update_edit_prediction_usage(usage, cx);
2038 });
2039 })
2040 .ok();
2041 }
2042 Ok(data)
2043 }
2044 Err(err) => {
2045 if err.is::<ZedUpdateRequiredError>() {
2046 cx.update(|cx| {
2047 this.update(cx, |this, _cx| {
2048 this.update_required = true;
2049 })
2050 .ok();
2051
2052 let error_message: SharedString = err.to_string().into();
2053 show_app_notification(
2054 NotificationId::unique::<ZedUpdateRequiredError>(),
2055 cx,
2056 move |cx| {
2057 cx.new(|cx| {
2058 ErrorMessagePrompt::new(error_message.clone(), cx)
2059 .with_link_button("Update Zed", "https://zed.dev/releases")
2060 })
2061 },
2062 );
2063 });
2064 }
2065 Err(err)
2066 }
2067 }
2068 }
2069
2070 async fn send_api_request<Res>(
2071 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2072 client: Arc<Client>,
2073 llm_token: LlmApiToken,
2074 app_version: Version,
2075 require_auth: bool,
2076 ) -> Result<(Res, Option<EditPredictionUsage>)>
2077 where
2078 Res: DeserializeOwned,
2079 {
2080 let http_client = client.http_client();
2081
2082 let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2083 Some(custom_token)
2084 } else if require_auth {
2085 Some(llm_token.acquire(&client).await?)
2086 } else {
2087 llm_token.acquire(&client).await.ok()
2088 };
2089 let mut did_retry = false;
2090
2091 loop {
2092 let request_builder = http_client::Request::builder().method(Method::POST);
2093
2094 let mut request_builder = request_builder
2095 .header("Content-Type", "application/json")
2096 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2097
2098 // Only add Authorization header if we have a token
2099 if let Some(ref token_value) = token {
2100 request_builder =
2101 request_builder.header("Authorization", format!("Bearer {}", token_value));
2102 }
2103
2104 let request = build(request_builder)?;
2105
2106 let mut response = http_client.send(request).await?;
2107
2108 if let Some(minimum_required_version) = response
2109 .headers()
2110 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2111 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2112 {
2113 anyhow::ensure!(
2114 app_version >= minimum_required_version,
2115 ZedUpdateRequiredError {
2116 minimum_version: minimum_required_version
2117 }
2118 );
2119 }
2120
2121 if response.status().is_success() {
2122 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2123
2124 let mut body = Vec::new();
2125 response.body_mut().read_to_end(&mut body).await?;
2126 return Ok((serde_json::from_slice(&body)?, usage));
2127 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2128 did_retry = true;
2129 token = Some(llm_token.refresh(&client).await?);
2130 } else {
2131 let mut body = String::new();
2132 response.body_mut().read_to_string(&mut body).await?;
2133 anyhow::bail!(
2134 "Request failed with status: {:?}\nBody: {}",
2135 response.status(),
2136 body
2137 );
2138 }
2139 }
2140 }
2141
2142 pub fn refresh_context(
2143 &mut self,
2144 project: &Entity<Project>,
2145 buffer: &Entity<language::Buffer>,
2146 cursor_position: language::Anchor,
2147 cx: &mut Context<Self>,
2148 ) {
2149 self.get_or_init_project(project, cx)
2150 .context
2151 .update(cx, |store, cx| {
2152 store.refresh(buffer.clone(), cursor_position, cx);
2153 });
2154 }
2155
2156 #[cfg(feature = "cli-support")]
2157 pub fn set_context_for_buffer(
2158 &mut self,
2159 project: &Entity<Project>,
2160 related_files: Vec<RelatedFile>,
2161 cx: &mut Context<Self>,
2162 ) {
2163 self.get_or_init_project(project, cx)
2164 .context
2165 .update(cx, |store, cx| {
2166 store.set_related_files(related_files, cx);
2167 });
2168 }
2169
2170 #[cfg(feature = "cli-support")]
2171 pub fn set_recent_paths_for_project(
2172 &mut self,
2173 project: &Entity<Project>,
2174 paths: impl IntoIterator<Item = project::ProjectPath>,
2175 cx: &mut Context<Self>,
2176 ) {
2177 let project_state = self.get_or_init_project(project, cx);
2178 project_state.recent_paths = paths.into_iter().collect();
2179 }
2180
2181 fn is_file_open_source(
2182 &self,
2183 project: &Entity<Project>,
2184 file: &Arc<dyn File>,
2185 cx: &App,
2186 ) -> bool {
2187 if !file.is_local() || file.is_private() {
2188 return false;
2189 }
2190 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2191 return false;
2192 };
2193 project_state
2194 .license_detection_watchers
2195 .get(&file.worktree_id(cx))
2196 .as_ref()
2197 .is_some_and(|watcher| watcher.is_project_open_source())
2198 }
2199
2200 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2201 self.data_collection_choice.is_enabled(cx) && self.is_file_open_source(project, file, cx)
2202 }
2203
2204 fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>], cx: &App) -> bool {
2205 if !self.data_collection_choice.is_enabled(cx) {
2206 return false;
2207 }
2208 events.iter().all(|event| {
2209 matches!(
2210 event.as_ref(),
2211 zeta_prompt::Event::BufferChange {
2212 in_open_source_repo: true,
2213 ..
2214 }
2215 )
2216 })
2217 }
2218
2219 fn load_data_collection_choice() -> DataCollectionChoice {
2220 let choice = KEY_VALUE_STORE
2221 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2222 .log_err()
2223 .flatten();
2224
2225 match choice.as_deref() {
2226 Some("true") => DataCollectionChoice::Enabled,
2227 Some("false") => DataCollectionChoice::Disabled,
2228 Some(_) => {
2229 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2230 DataCollectionChoice::NotAnswered
2231 }
2232 None => DataCollectionChoice::NotAnswered,
2233 }
2234 }
2235
2236 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2237 self.data_collection_choice = self.data_collection_choice.toggle();
2238 let new_choice = self.data_collection_choice;
2239 let is_enabled = new_choice.is_enabled(cx);
2240 db::write_and_log(cx, move || {
2241 KEY_VALUE_STORE.write_kvp(
2242 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2243 is_enabled.to_string(),
2244 )
2245 });
2246 }
2247
    /// Iterates over the entries of `shown_predictions`, in stored order
    /// (double-ended, so callers can also walk it from the back).
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Returns the number of entries in `shown_predictions`.
    ///
    /// NOTE(review): this says "completions" while the sibling accessor says
    /// "predictions". It is presumably kept under this name for existing callers.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Returns whether `id` is in `rated_predictions`. `rate_prediction` adds
    /// the prediction's id to that set.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2259
2260 pub fn rate_prediction(
2261 &mut self,
2262 prediction: &EditPrediction,
2263 rating: EditPredictionRating,
2264 feedback: String,
2265 cx: &mut Context<Self>,
2266 ) {
2267 self.rated_predictions.insert(prediction.id.clone());
2268 telemetry::event!(
2269 "Edit Prediction Rated",
2270 request_id = prediction.id.to_string(),
2271 rating,
2272 inputs = prediction.inputs,
2273 output = prediction
2274 .edit_preview
2275 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2276 feedback
2277 );
2278 self.client.telemetry().flush_events().detach();
2279 cx.notify();
2280 }
2281}
2282
/// Collapses the trailing run of mergeable events in `events` into one
/// `BufferChange` event.
///
/// The scan starts at the most recent event and walks backwards. It counts
/// consecutive events for which each older event `can_merge` with the newer
/// event after it; that check is evaluated against `latest_snapshot` and
/// `latest_edit_range`.
///
/// The run is merged only when both of these hold:
/// - At least two events qualify.
/// - `compute_diff_between_snapshots` returns a diff from the oldest qualifying
///   event's `old_snapshot` to `end_snapshot`.
///
/// In that case the run is replaced by a single event carrying that diff.
/// Otherwise `events` is left unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail out if the most recent event refers to a different buffer
    // (different `remote_id`) than the latest snapshot.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count how many events at the back form a chain where each older event can
    // merge with the newer one that follows it.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Zero or one trailing event: nothing to merge.
    if mergeable_count <= 1 {
        return;
    }

    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    // `mergeable_count >= 2` here, so the range is non-empty and `peek` yields the
    // oldest event in the run.
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    // If no diff is returned, the events are left as they are.
    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                // Paths and the open-source flag are taken from the oldest event.
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event counts as predicted only if every event
                    // in the run (including the peeked oldest one) was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                // The combined edited range is anchored in `end_snapshot`.
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Drop the merged run and append the single combined event in its place.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2349
2350pub(crate) fn filter_redundant_excerpts(
2351 mut related_files: Vec<RelatedFile>,
2352 cursor_path: &Path,
2353 cursor_row_range: Range<u32>,
2354) -> Vec<RelatedFile> {
2355 for file in &mut related_files {
2356 if file.path.as_ref() == cursor_path {
2357 file.excerpts.retain(|excerpt| {
2358 excerpt.row_range.start < cursor_row_range.start
2359 || excerpt.row_range.end > cursor_row_range.end
2360 });
2361 }
2362 }
2363 related_files.retain(|file| !file.excerpts.is_empty());
2364 related_files
2365}
2366
/// Error raised when the edit-prediction server requires a newer Zed version.
///
/// It is produced when a response carries a minimum-required-version header
/// naming a version greater than the running app version. Its `Display` text,
/// generated from `#[error]`, asks the user to update.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version parsed from the server's response header.
    minimum_version: Version,
}
2374
/// The user's stored answer to whether edit-prediction data may be collected.
///
/// This is only the recorded choice. The effective setting is computed by
/// `DataCollectionChoice::is_enabled`, which also considers the current account.
///
/// `PartialEq`/`Eq` are derived so callers can compare choices directly instead
/// of pattern-matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataCollectionChoice {
    /// No valid choice has been stored yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2381
2382impl DataCollectionChoice {
2383 pub fn is_enabled(self, cx: &App) -> bool {
2384 if cx.is_staff() {
2385 return true;
2386 }
2387 match self {
2388 Self::Enabled => true,
2389 Self::NotAnswered | Self::Disabled => false,
2390 }
2391 }
2392
2393 #[must_use]
2394 pub fn toggle(&self) -> DataCollectionChoice {
2395 match self {
2396 Self::Enabled => Self::Disabled,
2397 Self::Disabled => Self::Enabled,
2398 Self::NotAnswered => Self::Enabled,
2399 }
2400 }
2401}
2402
2403impl From<bool> for DataCollectionChoice {
2404 fn from(value: bool) -> Self {
2405 match value {
2406 true => DataCollectionChoice::Enabled,
2407 false => DataCollectionChoice::Disabled,
2408 }
2409 }
2410}
2411
2412struct ZedPredictUpsell;
2413
2414impl Dismissable for ZedPredictUpsell {
2415 const KEY: &'static str = "dismissed-edit-predict-upsell";
2416
2417 fn dismissed() -> bool {
2418 // To make this backwards compatible with older versions of Zed, we
2419 // check if the user has seen the previous Edit Prediction Onboarding
2420 // before, by checking the data collection choice which was written to
2421 // the database once the user clicked on "Accept and Enable"
2422 if KEY_VALUE_STORE
2423 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2424 .log_err()
2425 .is_some_and(|s| s.is_some())
2426 {
2427 return true;
2428 }
2429
2430 KEY_VALUE_STORE
2431 .read_kvp(Self::KEY)
2432 .log_err()
2433 .is_some_and(|s| s.is_some())
2434 }
2435}
2436
2437pub fn should_show_upsell_modal() -> bool {
2438 !ZedPredictUpsell::dismissed()
2439}
2440
2441pub fn init(cx: &mut App) {
2442 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2443 workspace.register_action(
2444 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2445 ZedPredictModal::toggle(
2446 workspace,
2447 workspace.user_store().clone(),
2448 workspace.client().clone(),
2449 window,
2450 cx,
2451 )
2452 },
2453 );
2454
2455 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2456 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2457 settings
2458 .project
2459 .all_languages
2460 .edit_predictions
2461 .get_or_insert_default()
2462 .provider = Some(EditPredictionProvider::None)
2463 });
2464 });
2465 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2466 EditPredictionStore::try_global(cx).and_then(|store| {
2467 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2468 })
2469 }
2470
2471 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2472 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2473 copilot_ui::initiate_sign_in(copilot, window, cx);
2474 }
2475 });
2476 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2477 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2478 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2479 }
2480 });
2481 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2482 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2483 copilot_ui::initiate_sign_out(copilot, window, cx);
2484 }
2485 });
2486 })
2487 .detach();
2488}