1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{
5 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
6};
7use cloud_llm_client::{
8 EditPredictionRejectReason, EditPredictionRejection,
9 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
10 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
11};
12use collections::{HashMap, HashSet};
13use copilot::{Copilot, Reinstall, SignIn, SignOut};
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::{
18 AsyncReadExt as _, FutureExt as _, StreamExt as _,
19 channel::mpsc::{self, UnboundedReceiver},
20 select_biased,
21};
22use gpui::BackgroundExecutor;
23use gpui::http_client::Url;
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::language_settings::all_language_settings;
30use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
33use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
34use release_channel::AppVersion;
35use semver::Version;
36use serde::de::DeserializeOwned;
37use settings::{EditPredictionProvider, Settings as _, update_settings_file};
38use std::collections::{VecDeque, hash_map};
39use std::env;
40use text::Edit;
41use workspace::Workspace;
42use zeta_prompt::{ZetaFormat, ZetaPromptInput};
43
44use std::mem;
45use std::ops::Range;
46use std::path::Path;
47use std::rc::Rc;
48use std::str::FromStr as _;
49use std::sync::Arc;
50use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
51use thiserror::Error;
52use util::{RangeExt as _, ResultExt as _};
53use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
54
55pub mod cursor_excerpt;
56pub mod example_spec;
57mod license_detection;
58pub mod mercury;
59pub mod ollama;
60mod onboarding_modal;
61pub mod open_ai_response;
62mod prediction;
63pub mod sweep_ai;
64
65pub mod udiff;
66
67mod capture_example;
68mod zed_edit_prediction_delegate;
69pub mod zeta1;
70pub mod zeta2;
71
72#[cfg(test)]
73mod edit_prediction_tests;
74
75use crate::license_detection::LicenseDetectionWatcher;
76use crate::mercury::Mercury;
77use crate::ollama::Ollama;
78use crate::onboarding_modal::ZedPredictModal;
79pub use crate::prediction::EditPrediction;
80pub use crate::prediction::EditPredictionId;
81use crate::prediction::EditPredictionResult;
82pub use crate::sweep_ai::SweepAi;
83pub use capture_example::capture_example;
84pub use language_model::ApiKeyState;
85pub use telemetry_events::EditPredictionRating;
86pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
87
// Actions declared in the `edit_prediction` namespace via gpui's `actions!` macro.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
97
/// Maximum number of edit events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance within which edits are considered close enough to group
/// (compared against `lines_between_ranges`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window for grouping consecutive changes (presumably an idle threshold;
/// its use is outside this view).
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_millis(1_000);
/// Key under which the user's data-collection choice is presumably persisted
/// in the key-value store.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// Debounce interval presumably applied before reporting rejected predictions.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_millis(15_000);
104
105pub struct Zeta2FeatureFlag;
106
107impl FeatureFlag for Zeta2FeatureFlag {
108 const NAME: &'static str = "zeta2";
109
110 fn enabled_for_staff() -> bool {
111 true
112 }
113}
114
/// Newtype holding the app-wide [`EditPredictionStore`] entity as a gpui global.
///
/// Installed by `EditPredictionStore::global` and read by
/// `EditPredictionStore::try_global`.
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}
119
/// Configuration for using the raw Zeta2 endpoint.
///
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The `format` (read from `ZED_ZETA_FORMAT`, which the env-parsing code calls
/// the "version") is also used as the Baseten environment name (lowercased).
///
/// Built from the environment by `EditPredictionStore::zeta2_raw_config_from_env`,
/// or supplied via `EditPredictionStore::set_zeta2_raw_config`.
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (from `ZED_ZETA_MODEL` when read from the environment).
    pub model_id: Option<String>,
    /// Prompt format (from `ZED_ZETA_FORMAT` when read from the environment).
    pub format: ZetaFormat,
}
128
/// Central edit prediction state: per-project history and pending requests,
/// the selected prediction model, provider clients, and LLM credentials.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed in the background whenever `RefreshLlmTokenListener` fires (see `new`).
    llm_token: LlmApiToken,
    // Keeps the LLM-token refresh subscription created in `new` alive.
    _llm_token_subscription: Subscription,
    // Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): starts as `false` in `new`; presumably set when the server
    // reports a minimum required version — confirm against the request code.
    update_required: bool,
    // Which backend is selected; `new` initializes this to `Zeta2`.
    edit_prediction_model: EditPredictionModel,
    // Optional raw-endpoint override for Zeta2.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    // Sender side of the channel drained by the background task spawned in `new`
    // (`handle_rejected_predictions`).
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): presumably predictions that were shown, kept for rating — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    // NOTE(review): presumably ids of predictions the user already rated — confirm.
    rated_predictions: HashSet<EditPredictionId>,
}
146
/// The backend used to produce edit predictions.
///
/// Note that `Default` yields `Zeta1`, while `EditPredictionStore::new` starts
/// with `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
    Ollama,
}
156
/// Inputs gathered for a single edit prediction request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    // Snapshot of `buffer` used for this request.
    snapshot: BufferSnapshot,
    // NOTE(review): presumably the cursor position the prediction is requested for.
    position: Anchor,
    // Recent edit-history events.
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    // Recently active paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    // NOTE(review): presumably the row range searched for diagnostics — confirm.
    diagnostic_search_range: Range<Point>,
    // When set, debug events for this request are sent here.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
171
/// Diagnostic events delivered to a project's debug channel (installed via
/// `EditPredictionStore::debug_info`).
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when the related-excerpt store starts refreshing context.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Left empty when emitted from `handle_excerpt_store_event`.
    pub search_prompt: String,
}

/// Emitted when a context refresh finishes.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Human-readable `(label, value)` pairs, e.g. cache hits and LSP latency.
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when an edit prediction request starts.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably the prompt sent to the model, when one is built locally.
    pub prompt: Option<String>,
}

/// Emitted when an edit prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably the raw model output, when available.
    pub model_output: Option<String>,
}
207
/// Capacity of `ProjectState::user_actions`; `record_user_action` evicts the
/// oldest entry once this many are held.
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action within a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    // Entity id of the buffer the action happened in.
    pub buffer_id: EntityId,
    // NOTE(review): presumably the row and offset of the cursor at the time of the action.
    pub line_number: u32,
    pub offset: usize,
    // Milliseconds since the Unix epoch, per the field name.
    pub timestamp_epoch_ms: u64,
}
218
/// Kind of action recorded in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
227
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    /// The recorded buffer-change event.
    pub event: Arc<zeta_prompt::Event>,
    /// Snapshot of the buffer before this event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Range spanning the edits, anchored in the post-edit snapshot.
    pub edit_range: Range<Anchor>,
}
235
236impl StoredEvent {
237 fn can_merge(
238 &self,
239 next_old_event: &&&StoredEvent,
240 new_snapshot: &TextBufferSnapshot,
241 last_edit_range: &Range<Anchor>,
242 ) -> bool {
243 // Events must be for the same buffer
244 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
245 return false;
246 }
247
248 let a_is_predicted = matches!(
249 self.event.as_ref(),
250 zeta_prompt::Event::BufferChange {
251 predicted: true,
252 ..
253 }
254 );
255 let b_is_predicted = matches!(
256 next_old_event.event.as_ref(),
257 zeta_prompt::Event::BufferChange {
258 predicted: true,
259 ..
260 }
261 );
262
263 // If events come from the same source (both predicted or both manual) then
264 // we would have coalesced them already.
265 if a_is_predicted == b_is_predicted {
266 return false;
267 }
268
269 let left_range = self.edit_range.to_point(new_snapshot);
270 let right_range = next_old_event.edit_range.to_point(new_snapshot);
271 let latest_range = last_edit_range.to_point(&new_snapshot);
272
273 // Events near to the latest edit are not merged if their sources differ.
274 if lines_between_ranges(&left_range, &latest_range)
275 .min(lines_between_ranges(&right_range, &latest_range))
276 <= CHANGE_GROUPING_LINE_SPAN
277 {
278 return false;
279 }
280
281 // Events that are distant from each other are not merged.
282 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
283 return false;
284 }
285
286 true
287 }
288}
289
290fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
291 if left.start > right.end {
292 return left.start.row - right.end.row;
293 }
294 if right.start > left.end {
295 return right.start.row - left.end.row;
296 }
297 0
298}
299
/// Edit prediction state tracked for a single project.
struct ProjectState {
    // Finalized edit events, oldest first (presumably bounded by `EVENT_COUNT_MAX`).
    events: VecDeque<StoredEvent>,
    // The in-progress event, finalized on demand by `events()`.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first, without duplicates.
    recent_paths: VecDeque<ProjectPath>,
    // Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // Fixed capacity: at most two in-flight requests are tracked.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Debug channel installed by `EditPredictionStore::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    // NOTE(review): presumably (buffer id, time) of the last refresh, for throttling — confirm.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // License watchers keyed by worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded (`USER_ACTION_HISTORY_SIZE`) history of recent user actions.
    user_actions: VecDeque<UserActionRecord>,
    // Project event subscription and project release observer.
    _subscriptions: [gpui::Subscription; 2],
    // Started lazily by `EditPredictionStore::start_copilot_for_project`.
    copilot: Option<Entity<Copilot>>,
}
317
318impl ProjectState {
319 fn record_user_action(&mut self, action: UserActionRecord) {
320 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
321 self.user_actions.pop_front();
322 }
323 self.user_actions.push_back(action);
324 }
325
326 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
327 self.events
328 .iter()
329 .cloned()
330 .chain(self.last_event.as_ref().iter().flat_map(|event| {
331 let (one, two) = event.split_by_pause();
332 let one = one.finalize(&self.license_detection_watchers, cx);
333 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
334 one.into_iter().chain(two)
335 }))
336 .collect()
337 }
338
339 fn cancel_pending_prediction(
340 &mut self,
341 pending_prediction: PendingPrediction,
342 cx: &mut Context<EditPredictionStore>,
343 ) {
344 self.cancelled_predictions.insert(pending_prediction.id);
345
346 if pending_prediction.drop_on_cancel {
347 drop(pending_prediction.task);
348 } else {
349 cx.spawn(async move |this, cx| {
350 let Some(prediction_id) = pending_prediction.task.await else {
351 return;
352 };
353
354 this.update(cx, |this, cx| {
355 this.reject_prediction(
356 prediction_id,
357 EditPredictionRejectReason::Canceled,
358 false,
359 cx,
360 );
361 })
362 .ok();
363 })
364 .detach()
365 }
366 }
367
368 fn active_buffer(
369 &self,
370 project: &Entity<Project>,
371 cx: &App,
372 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
373 let project = project.read(cx);
374 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
375 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
376 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
377 Some((active_buffer, registered_buffer.last_position))
378 }
379}
380
/// The prediction currently offered for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    // What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed — confirm.
    pub was_shown: bool,
    // NOTE(review): presumably how it was displayed, once shown — confirm.
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
388
389impl CurrentEditPrediction {
390 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
391 let Some(new_edits) = self
392 .prediction
393 .interpolate(&self.prediction.buffer.read(cx))
394 else {
395 return false;
396 };
397
398 if self.prediction.buffer != old_prediction.prediction.buffer {
399 return true;
400 }
401
402 let Some(old_edits) = old_prediction
403 .prediction
404 .interpolate(&old_prediction.prediction.buffer.read(cx))
405 else {
406 return true;
407 };
408
409 let requested_by_buffer_id = self.requested_by.buffer_id();
410
411 // This reduces the occurrence of UI thrash from replacing edits
412 //
413 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
414 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
415 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
416 && old_edits.len() == 1
417 && new_edits.len() == 1
418 {
419 let (old_range, old_text) = &old_edits[0];
420 let (new_range, new_text) = &new_edits[0];
421 new_range == old_range && new_text.starts_with(old_text.as_ref())
422 } else {
423 true
424 }
425 }
426}
427
428#[derive(Debug, Clone)]
429enum PredictionRequestedBy {
430 DiagnosticsUpdate,
431 Buffer(EntityId),
432}
433
434impl PredictionRequestedBy {
435 pub fn buffer_id(&self) -> Option<EntityId> {
436 match self {
437 PredictionRequestedBy::DiagnosticsUpdate => None,
438 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
439 }
440 }
441}
442
/// An in-flight prediction request held in `ProjectState::pending_predictions`.
#[derive(Debug)]
struct PendingPrediction {
    // Identifier inserted into `ProjectState::cancelled_predictions` on cancel.
    id: usize,
    // Yields the produced prediction's id; `None` leaves nothing to reject on cancel.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
451
/// A prediction from the perspective of a buffer.
///
/// NOTE(review): presumably `Local` means the prediction edits this buffer and
/// `Jump` means its edits lie elsewhere (requiring navigation) — confirm.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
458
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Both variants wrap a prediction reference; deref exposes it directly.
    fn deref(&self) -> &Self::Target {
        let (Self::Local { prediction } | Self::Jump { prediction }) = self;
        prediction
    }
}
470
/// Per-buffer tracking state created by `register_buffer_impl`.
struct RegisteredBuffer {
    // File as of the last reported change; replaced in `report_changes_for_buffer`.
    file: Option<Arc<dyn File>>,
    // Snapshot as of the last reported change; new edits are diffed against it.
    snapshot: TextBufferSnapshot,
    // Last known position in this buffer (returned by `ProjectState::active_buffer`);
    // starts as `None`.
    last_position: Option<Anchor>,
    // Buffer edit subscription and release observer.
    _subscriptions: [gpui::Subscription; 2],
}
477
/// The most recent, not-yet-stored edit event for a buffer, spanning
/// `old_snapshot..new_snapshot`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    // Not read by `finalize`, which recomputes the range from the snapshot diff.
    edit_range: Option<Range<Anchor>>,
    // NOTE(review): presumably true when the edits came from applying a prediction.
    predicted: bool,
    // Snapshot at the last editing pause; `split_by_pause` splits the event here.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
489
490impl LastEvent {
491 pub fn finalize(
492 &self,
493 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
494 cx: &App,
495 ) -> Option<StoredEvent> {
496 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
497 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
498
499 let in_open_source_repo =
500 [self.new_file.as_ref(), self.old_file.as_ref()]
501 .iter()
502 .all(|file| {
503 file.is_some_and(|file| {
504 license_detection_watchers
505 .get(&file.worktree_id(cx))
506 .is_some_and(|watcher| watcher.is_project_open_source())
507 })
508 });
509
510 let (diff, edit_range) =
511 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
512
513 if path == old_path && diff.is_empty() {
514 None
515 } else {
516 Some(StoredEvent {
517 event: Arc::new(zeta_prompt::Event::BufferChange {
518 old_path,
519 path,
520 diff,
521 in_open_source_repo,
522 predicted: self.predicted,
523 }),
524 edit_range: self.new_snapshot.anchor_before(edit_range.start)
525 ..self.new_snapshot.anchor_before(edit_range.end),
526 old_snapshot: self.old_snapshot.clone(),
527 })
528 }
529 }
530
531 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
532 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
533 return (self.clone(), None);
534 };
535
536 let before = LastEvent {
537 old_snapshot: self.old_snapshot.clone(),
538 new_snapshot: boundary_snapshot.clone(),
539 old_file: self.old_file.clone(),
540 new_file: self.new_file.clone(),
541 edit_range: None,
542 predicted: self.predicted,
543 snapshot_after_last_editing_pause: None,
544 last_edit_time: self.last_edit_time,
545 };
546
547 let after = LastEvent {
548 old_snapshot: boundary_snapshot.clone(),
549 new_snapshot: self.new_snapshot.clone(),
550 old_file: self.old_file.clone(),
551 new_file: self.new_file.clone(),
552 edit_range: None,
553 predicted: self.predicted,
554 snapshot_after_last_editing_pause: None,
555 last_edit_time: self.last_edit_time,
556 };
557
558 (before, Some(after))
559 }
560}
561
562pub(crate) fn compute_diff_between_snapshots(
563 old_snapshot: &TextBufferSnapshot,
564 new_snapshot: &TextBufferSnapshot,
565) -> Option<(String, Range<Point>)> {
566 let edits: Vec<Edit<usize>> = new_snapshot
567 .edits_since::<usize>(&old_snapshot.version)
568 .collect();
569
570 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
571
572 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
573 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
574 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
575 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
576
577 const CONTEXT_LINES: u32 = 3;
578
579 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
580 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
581 let old_context_end_row =
582 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
583 let new_context_end_row =
584 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
585
586 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
587 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
588 let old_end_line_offset = old_snapshot
589 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
590 let new_end_line_offset = new_snapshot
591 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
592 let old_edit_range = old_start_line_offset..old_end_line_offset;
593 let new_edit_range = new_start_line_offset..new_end_line_offset;
594
595 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
596 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
597
598 let diff = language::unified_diff_with_offsets(
599 &old_region_text,
600 &new_region_text,
601 old_context_start_row,
602 new_context_start_row,
603 );
604
605 Some((diff, new_start_point..new_end_point))
606}
607
608fn buffer_path_with_id_fallback(
609 file: Option<&Arc<dyn File>>,
610 snapshot: &TextBufferSnapshot,
611 cx: &App,
612) -> Arc<Path> {
613 if let Some(file) = file {
614 file.full_path(cx).into()
615 } else {
616 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
617 }
618}
619
620impl EditPredictionStore {
621 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
622 cx.try_global::<EditPredictionStoreGlobal>()
623 .map(|global| global.0.clone())
624 }
625
626 pub fn global(
627 client: &Arc<Client>,
628 user_store: &Entity<UserStore>,
629 cx: &mut App,
630 ) -> Entity<Self> {
631 cx.try_global::<EditPredictionStoreGlobal>()
632 .map(|global| global.0.clone())
633 .unwrap_or_else(|| {
634 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
635 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
636 ep_store
637 })
638 }
639
640 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
641 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
642 let data_collection_choice = Self::load_data_collection_choice();
643
644 let llm_token = LlmApiToken::default();
645
646 let (reject_tx, reject_rx) = mpsc::unbounded();
647 cx.background_spawn({
648 let client = client.clone();
649 let llm_token = llm_token.clone();
650 let app_version = AppVersion::global(cx);
651 let background_executor = cx.background_executor().clone();
652 async move {
653 Self::handle_rejected_predictions(
654 reject_rx,
655 client,
656 llm_token,
657 app_version,
658 background_executor,
659 )
660 .await
661 }
662 })
663 .detach();
664
665 let this = Self {
666 projects: HashMap::default(),
667 client,
668 user_store,
669 llm_token,
670 _llm_token_subscription: cx.subscribe(
671 &refresh_llm_token_listener,
672 |this, _listener, _event, cx| {
673 let client = this.client.clone();
674 let llm_token = this.llm_token.clone();
675 cx.spawn(async move |_this, _cx| {
676 llm_token.refresh(&client).await?;
677 anyhow::Ok(())
678 })
679 .detach_and_log_err(cx);
680 },
681 ),
682 update_required: false,
683 edit_prediction_model: EditPredictionModel::Zeta2,
684 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
685 sweep_ai: SweepAi::new(cx),
686 mercury: Mercury::new(cx),
687 ollama: Ollama::new(),
688
689 data_collection_choice,
690 reject_predictions_tx: reject_tx,
691 rated_predictions: Default::default(),
692 shown_predictions: Default::default(),
693 };
694
695 this
696 }
697
698 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
699 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
700 let format = ZetaFormat::parse(&version_str).ok()?;
701 let model_id = env::var("ZED_ZETA_MODEL").ok();
702 Some(Zeta2RawConfig { model_id, format })
703 }
704
705 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
706 self.edit_prediction_model = model;
707 }
708
709 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
710 self.zeta2_raw_config = Some(config);
711 }
712
713 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
714 self.zeta2_raw_config.as_ref()
715 }
716
717 pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
718 use ui::IconName;
719 match self.edit_prediction_model {
720 EditPredictionModel::Sweep => {
721 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
722 .with_disabled(IconName::SweepAiDisabled)
723 .with_up(IconName::SweepAiUp)
724 .with_down(IconName::SweepAiDown)
725 .with_error(IconName::SweepAiError)
726 }
727 EditPredictionModel::Mercury => {
728 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
729 }
730 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
731 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
732 .with_disabled(IconName::ZedPredictDisabled)
733 .with_up(IconName::ZedPredictUp)
734 .with_down(IconName::ZedPredictDown)
735 .with_error(IconName::ZedPredictError)
736 }
737 EditPredictionModel::Ollama => {
738 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
739 }
740 }
741 }
742
743 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
744 self.sweep_ai.api_token.read(cx).has_key()
745 }
746
747 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
748 self.mercury.api_token.read(cx).has_key()
749 }
750
751 pub fn clear_history(&mut self) {
752 for project_state in self.projects.values_mut() {
753 project_state.events.clear();
754 project_state.last_event.take();
755 }
756 }
757
758 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
759 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
760 project_state.events.clear();
761 project_state.last_event.take();
762 }
763 }
764
765 pub fn edit_history_for_project(
766 &self,
767 project: &Entity<Project>,
768 cx: &App,
769 ) -> Vec<StoredEvent> {
770 self.projects
771 .get(&project.entity_id())
772 .map(|project_state| project_state.events(cx))
773 .unwrap_or_default()
774 }
775
776 pub fn context_for_project<'a>(
777 &'a self,
778 project: &Entity<Project>,
779 cx: &'a mut App,
780 ) -> Vec<RelatedFile> {
781 self.projects
782 .get(&project.entity_id())
783 .map(|project_state| {
784 project_state.context.update(cx, |context, cx| {
785 context
786 .related_files_with_buffers(cx)
787 .map(|(mut related_file, buffer)| {
788 related_file.in_open_source_repo = buffer
789 .read(cx)
790 .file()
791 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
792 related_file
793 })
794 .collect()
795 })
796 })
797 .unwrap_or_default()
798 }
799
800 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
801 self.projects
802 .get(&project.entity_id())
803 .and_then(|project| project.copilot.clone())
804 }
805
806 pub fn start_copilot_for_project(
807 &mut self,
808 project: &Entity<Project>,
809 cx: &mut Context<Self>,
810 ) -> Option<Entity<Copilot>> {
811 if DisableAiSettings::get(None, cx).disable_ai {
812 return None;
813 }
814 let state = self.get_or_init_project(project, cx);
815
816 if state.copilot.is_some() {
817 return state.copilot.clone();
818 }
819 let _project = project.clone();
820 let project = project.read(cx);
821
822 let node = project.node_runtime().cloned();
823 if let Some(node) = node {
824 let next_id = project.languages().next_language_server_id();
825 let fs = project.fs().clone();
826
827 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
828 state.copilot = Some(copilot.clone());
829 Some(copilot)
830 } else {
831 None
832 }
833 }
834
835 pub fn context_for_project_with_buffers<'a>(
836 &'a self,
837 project: &Entity<Project>,
838 cx: &'a mut App,
839 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
840 self.projects
841 .get(&project.entity_id())
842 .map(|project| {
843 project.context.update(cx, |context, cx| {
844 context.related_files_with_buffers(cx).collect()
845 })
846 })
847 .unwrap_or_default()
848 }
849
850 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
851 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
852 self.user_store.read(cx).edit_prediction_usage()
853 } else {
854 None
855 }
856 }
857
858 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
859 self.get_or_init_project(project, cx);
860 }
861
862 pub fn register_buffer(
863 &mut self,
864 buffer: &Entity<Buffer>,
865 project: &Entity<Project>,
866 cx: &mut Context<Self>,
867 ) {
868 let project_state = self.get_or_init_project(project, cx);
869 Self::register_buffer_impl(project_state, buffer, project, cx);
870 }
871
872 fn get_or_init_project(
873 &mut self,
874 project: &Entity<Project>,
875 cx: &mut Context<Self>,
876 ) -> &mut ProjectState {
877 let entity_id = project.entity_id();
878 self.projects
879 .entry(entity_id)
880 .or_insert_with(|| ProjectState {
881 context: {
882 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
883 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
884 this.handle_excerpt_store_event(entity_id, event);
885 })
886 .detach();
887 related_excerpt_store
888 },
889 events: VecDeque::new(),
890 last_event: None,
891 recent_paths: VecDeque::new(),
892 debug_tx: None,
893 registered_buffers: HashMap::default(),
894 current_prediction: None,
895 cancelled_predictions: HashSet::default(),
896 pending_predictions: ArrayVec::new(),
897 next_pending_prediction_id: 0,
898 last_prediction_refresh: None,
899 license_detection_watchers: HashMap::default(),
900 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
901 _subscriptions: [
902 cx.subscribe(&project, Self::handle_project_event),
903 cx.observe_release(&project, move |this, _, cx| {
904 this.projects.remove(&entity_id);
905 cx.notify();
906 }),
907 ],
908 copilot: None,
909 })
910 }
911
912 pub fn remove_project(&mut self, project: &Entity<Project>) {
913 self.projects.remove(&project.entity_id());
914 }
915
916 fn handle_excerpt_store_event(
917 &mut self,
918 project_entity_id: EntityId,
919 event: &RelatedExcerptStoreEvent,
920 ) {
921 if let Some(project_state) = self.projects.get(&project_entity_id) {
922 if let Some(debug_tx) = project_state.debug_tx.clone() {
923 match event {
924 RelatedExcerptStoreEvent::StartedRefresh => {
925 debug_tx
926 .unbounded_send(DebugEvent::ContextRetrievalStarted(
927 ContextRetrievalStartedDebugEvent {
928 project_entity_id: project_entity_id,
929 timestamp: Instant::now(),
930 search_prompt: String::new(),
931 },
932 ))
933 .ok();
934 }
935 RelatedExcerptStoreEvent::FinishedRefresh {
936 cache_hit_count,
937 cache_miss_count,
938 mean_definition_latency,
939 max_definition_latency,
940 } => {
941 debug_tx
942 .unbounded_send(DebugEvent::ContextRetrievalFinished(
943 ContextRetrievalFinishedDebugEvent {
944 project_entity_id: project_entity_id,
945 timestamp: Instant::now(),
946 metadata: vec![
947 (
948 "Cache Hits",
949 format!(
950 "{}/{}",
951 cache_hit_count,
952 cache_hit_count + cache_miss_count
953 )
954 .into(),
955 ),
956 (
957 "Max LSP Time",
958 format!("{} ms", max_definition_latency.as_millis())
959 .into(),
960 ),
961 (
962 "Mean LSP Time",
963 format!("{} ms", mean_definition_latency.as_millis())
964 .into(),
965 ),
966 ],
967 },
968 ))
969 .ok();
970 }
971 }
972 }
973 }
974 }
975
976 pub fn debug_info(
977 &mut self,
978 project: &Entity<Project>,
979 cx: &mut Context<Self>,
980 ) -> mpsc::UnboundedReceiver<DebugEvent> {
981 let project_state = self.get_or_init_project(project, cx);
982 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
983 project_state.debug_tx = Some(debug_watch_tx);
984 debug_watch_rx
985 }
986
987 fn handle_project_event(
988 &mut self,
989 project: Entity<Project>,
990 event: &project::Event,
991 cx: &mut Context<Self>,
992 ) {
993 // TODO [zeta2] init with recent paths
994 match event {
995 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
996 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
997 return;
998 };
999 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1000 if let Some(path) = path {
1001 if let Some(ix) = project_state
1002 .recent_paths
1003 .iter()
1004 .position(|probe| probe == &path)
1005 {
1006 project_state.recent_paths.remove(ix);
1007 }
1008 project_state.recent_paths.push_front(path);
1009 }
1010 }
1011 project::Event::DiagnosticsUpdated { .. } => {
1012 if cx.has_flag::<Zeta2FeatureFlag>() {
1013 self.refresh_prediction_from_diagnostics(project, cx);
1014 }
1015 }
1016 _ => (),
1017 }
1018 }
1019
/// Ensures `buffer` is tracked in `project_state` and returns its registration.
///
/// If the buffer has a file and `project` has a worktree with that file's
/// worktree id, a license-detection watcher is created for the worktree the
/// first time it is seen; it is removed again when the worktree is released.
///
/// A new registration captures the buffer's current text snapshot and file. It
/// subscribes to `Edited` events, forwarding them to
/// `report_changes_for_buffer` with `is_predicted = false`, and unregisters
/// itself when the buffer is released. An existing registration is returned
/// unchanged.
fn register_buffer_impl<'a>(
    project_state: &'a mut ProjectState,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) -> &'a mut RegisteredBuffer {
    let buffer_id = buffer.entity_id();

    if let Some(file) = buffer.read(cx).file() {
        let worktree_id = file.worktree_id(cx);
        if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
            project_state
                .license_detection_watchers
                .entry(worktree_id)
                .or_insert_with(|| {
                    let project_entity_id = project.entity_id();
                    // Drop the watcher once its worktree goes away.
                    cx.observe_release(&worktree, move |this, _worktree, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state
                            .license_detection_watchers
                            .remove(&worktree_id);
                    })
                    .detach();
                    Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                });
        }
    }

    match project_state.registered_buffers.entry(buffer_id) {
        hash_map::Entry::Occupied(entry) => entry.into_mut(),
        hash_map::Entry::Vacant(entry) => {
            let buf = buffer.read(cx);
            let snapshot = buf.text_snapshot();
            let file = buf.file().cloned();
            let project_entity_id = project.entity_id();
            entry.insert(RegisteredBuffer {
                snapshot,
                file,
                last_position: None,
                _subscriptions: [
                    // Report each edit as a non-predicted change while the project is alive.
                    cx.subscribe(buffer, {
                        let project = project.downgrade();
                        move |this, buffer, event, cx| {
                            if let language::BufferEvent::Edited = event
                                && let Some(project) = project.upgrade()
                            {
                                this.report_changes_for_buffer(&buffer, &project, false, cx);
                            }
                        }
                    }),
                    // Stop tracking the buffer once it is released.
                    cx.observe_release(buffer, move |this, _buffer, _cx| {
                        let Some(project_state) = this.projects.get_mut(&project_entity_id)
                        else {
                            return;
                        };
                        project_state.registered_buffers.remove(&buffer_id);
                    }),
                ],
            })
        }
    }
}
1085
    /// Records the edits made to `buffer` since the snapshot stored at its last report.
    ///
    /// The buffer is registered first if needed, and nothing happens when its version is
    /// unchanged. Otherwise this method does the following:
    /// - It classifies the edits as a `UserActionType` and records a `UserActionRecord`
    ///   positioned at the end of the last edit.
    /// - It extends `project_state.last_event` in place when the new edits continue from the
    ///   previous event's snapshot of the same buffer, come from the same source
    ///   (`is_predicted`), and lie within `CHANGE_GROUPING_LINE_SPAN` lines of the previous
    ///   edit range.
    /// - Otherwise it finalizes the previous `last_event` into `events`, calls
    ///   `merge_trailing_events_if_needed`, and starts a new `LastEvent` for these edits.
    ///
    /// `is_predicted` is `true` when called for an accepted prediction and `false` for regular
    /// buffer edits.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Store the new state; the previous snapshot is the baseline for the diff below.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Accumulate edit statistics plus one anchor range spanning from the first edit's
        // start to the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Edits whose total inserted (or deleted) length equals the number of edits, i.e. a
        // length of one per edit, count as `*Char` actions. Anything larger counts as
        // `*Selection`. Edits that both delete and insert are classified by their insertions.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // A clock before the Unix epoch is recorded as timestamp 0.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // These edits start exactly where the previous event's buffer snapshot ended.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // If at least `LAST_CHANGE_GROUPING_TIME` passed since the previous edit,
                // remember the snapshot from just before this edit.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                // Evict the oldest entry so `events` holds at most `EVENT_COUNT_MAX - 1`
                // entries after the push. NOTE(review): presumably this leaves room for
                // `last_event`; confirm.
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1213
1214 fn prediction_at(
1215 &mut self,
1216 buffer: &Entity<Buffer>,
1217 position: Option<language::Anchor>,
1218 project: &Entity<Project>,
1219 cx: &App,
1220 ) -> Option<BufferEditPrediction<'_>> {
1221 let project_state = self.projects.get_mut(&project.entity_id())?;
1222 if let Some(position) = position
1223 && let Some(buffer) = project_state
1224 .registered_buffers
1225 .get_mut(&buffer.entity_id())
1226 {
1227 buffer.last_position = Some(position);
1228 }
1229
1230 let CurrentEditPrediction {
1231 requested_by,
1232 prediction,
1233 ..
1234 } = project_state.current_prediction.as_ref()?;
1235
1236 if prediction.targets_buffer(buffer.read(cx)) {
1237 Some(BufferEditPrediction::Local { prediction })
1238 } else {
1239 let show_jump = match requested_by {
1240 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1241 requested_by_buffer_id == &buffer.entity_id()
1242 }
1243 PredictionRequestedBy::DiagnosticsUpdate => true,
1244 };
1245
1246 if show_jump {
1247 Some(BufferEditPrediction::Jump { prediction })
1248 } else {
1249 None
1250 }
1251 }
1252 }
1253
1254 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1255 let Some(current_prediction) = self
1256 .projects
1257 .get_mut(&project.entity_id())
1258 .and_then(|project_state| project_state.current_prediction.take())
1259 else {
1260 return;
1261 };
1262
1263 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1264
1265 // can't hold &mut project_state ref across report_changes_for_buffer_call
1266 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1267 return;
1268 };
1269
1270 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1271 project_state.cancel_pending_prediction(pending_prediction, cx);
1272 }
1273
1274 match self.edit_prediction_model {
1275 EditPredictionModel::Sweep => {
1276 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1277 }
1278 EditPredictionModel::Mercury => {
1279 mercury::edit_prediction_accepted(
1280 current_prediction.prediction.id,
1281 self.client.http_client(),
1282 cx,
1283 );
1284 }
1285 EditPredictionModel::Ollama => {}
1286 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1287 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1288 }
1289 }
1290 }
1291
    /// Background loop that batches prediction rejections and posts them to
    /// `/predict_edits/reject`.
    ///
    /// Each rejection received on `rx` is appended to a local batch. While the batch holds
    /// fewer than half of `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` entries, the loop waits
    /// up to `REJECT_REQUEST_DEBOUNCE` and keeps accumulating if another rejection arrives
    /// first. It then sends the most recent entries, at most
    /// `MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST` of them. Those entries are removed from the
    /// batch only when the request succeeds. Entries from failed requests (failures are logged)
    /// stay in the batch for later flushes. The loop ends when the channel is closed.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        // Peekable so the debounce can check for another rejection without consuming it.
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // Biased toward the channel: if another rejection is ready, or arrives before
                // the debounce timer fires, loop back and collect it. A closed channel (`None`)
                // falls through and flushes what has been collected.
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            // NOTE(review): panics if the URL cannot be built; it is rebuilt on every flush.
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the newest `flush_count` entries.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
                true,
            )
            .await;

            // NOTE(review): `batched` has no upper bound if requests keep failing; confirm this
            // is acceptable.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
1350
1351 fn reject_current_prediction(
1352 &mut self,
1353 reason: EditPredictionRejectReason,
1354 project: &Entity<Project>,
1355 cx: &App,
1356 ) {
1357 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1358 project_state.pending_predictions.clear();
1359 if let Some(prediction) = project_state.current_prediction.take() {
1360 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1361 }
1362 };
1363 }
1364
1365 fn did_show_current_prediction(
1366 &mut self,
1367 project: &Entity<Project>,
1368 display_type: edit_prediction_types::SuggestionDisplayType,
1369 cx: &mut Context<Self>,
1370 ) {
1371 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1372 return;
1373 };
1374
1375 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1376 return;
1377 };
1378
1379 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1380 let previous_shown_with = current_prediction.shown_with;
1381
1382 if previous_shown_with.is_none() || !is_jump {
1383 current_prediction.shown_with = Some(display_type);
1384 }
1385
1386 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1387
1388 if is_first_non_jump_show {
1389 current_prediction.was_shown = true;
1390 }
1391
1392 let display_type_changed = previous_shown_with != Some(display_type);
1393
1394 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1395 sweep_ai::edit_prediction_shown(
1396 &self.sweep_ai,
1397 self.client.clone(),
1398 ¤t_prediction.prediction,
1399 display_type,
1400 cx,
1401 );
1402 }
1403
1404 if is_first_non_jump_show {
1405 self.shown_predictions
1406 .push_front(current_prediction.prediction.clone());
1407 if self.shown_predictions.len() > 50 {
1408 let completion = self.shown_predictions.pop_back().unwrap();
1409 self.rated_predictions.remove(&completion.id);
1410 }
1411 }
1412 }
1413
1414 fn reject_prediction(
1415 &mut self,
1416 prediction_id: EditPredictionId,
1417 reason: EditPredictionRejectReason,
1418 was_shown: bool,
1419 cx: &App,
1420 ) {
1421 match self.edit_prediction_model {
1422 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1423 self.reject_predictions_tx
1424 .unbounded_send(EditPredictionRejection {
1425 request_id: prediction_id.to_string(),
1426 reason,
1427 was_shown,
1428 })
1429 .log_err();
1430 }
1431 EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1432 EditPredictionModel::Mercury => {
1433 mercury::edit_prediction_rejected(
1434 prediction_id,
1435 was_shown,
1436 reason,
1437 self.client.http_client(),
1438 cx,
1439 );
1440 }
1441 }
1442 }
1443
1444 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1445 self.projects
1446 .get(&project.entity_id())
1447 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1448 }
1449
1450 pub fn refresh_prediction_from_buffer(
1451 &mut self,
1452 project: Entity<Project>,
1453 buffer: Entity<Buffer>,
1454 position: language::Anchor,
1455 cx: &mut Context<Self>,
1456 ) {
1457 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1458 let Some(request_task) = this
1459 .update(cx, |this, cx| {
1460 this.request_prediction(
1461 &project,
1462 &buffer,
1463 position,
1464 PredictEditsRequestTrigger::Other,
1465 cx,
1466 )
1467 })
1468 .log_err()
1469 else {
1470 return Task::ready(anyhow::Ok(None));
1471 };
1472
1473 cx.spawn(async move |_cx| {
1474 request_task.await.map(|prediction_result| {
1475 prediction_result.map(|prediction_result| {
1476 (
1477 prediction_result,
1478 PredictionRequestedBy::Buffer(buffer.entity_id()),
1479 )
1480 })
1481 })
1482 })
1483 })
1484 }
1485
    /// Queues a prediction request triggered by a diagnostics update in `project`.
    ///
    /// Does nothing if the project has no state or already has a current prediction.
    /// Otherwise it queues a refresh, throttled per project, which does the following:
    /// 1. Reads the active buffer and cursor, and stops if predictions are disabled there.
    /// 2. Asks `next_diagnostic_location` for a location to jump to.
    /// 3. Requests a prediction there with the `Diagnostics` trigger.
    ///
    /// If a current prediction appeared while the request was in flight, the new result is
    /// turned into a `CurrentPreferred` rejection. Such results are tagged
    /// `PredictionRequestedBy::DiagnosticsUpdate`.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if project_state.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some((active_buffer, snapshot, cursor_point)) = this
                .read_with(cx, |this, cx| {
                    let project_state = this.projects.get(&project.entity_id())?;
                    let (buffer, position) = project_state.active_buffer(&project, cx)?;
                    let snapshot = buffer.read(cx).snapshot();

                    if !Self::predictions_enabled_at(&snapshot, position, cx) {
                        return None;
                    }

                    // Without a known cursor position, search from the default point.
                    let cursor_point = position
                        .map(|pos| pos.to_point(&snapshot))
                        .unwrap_or_default();

                    Some((buffer, snapshot, cursor_point))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                // NOTE(review): an empty (default) search range is passed, presumably so that no
                // diagnostic in the active buffer is excluded; confirm against `overlaps`.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A prediction that arrived in the meantime wins over this one.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1572
1573 fn predictions_enabled_at(
1574 snapshot: &BufferSnapshot,
1575 position: Option<language::Anchor>,
1576 cx: &App,
1577 ) -> bool {
1578 let file = snapshot.file();
1579 let all_settings = all_language_settings(file, cx);
1580 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1581 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1582 {
1583 return false;
1584 }
1585
1586 if let Some(last_position) = position {
1587 let settings = snapshot.settings_at(last_position, cx);
1588
1589 if !settings.edit_predictions_disabled_in.is_empty()
1590 && let Some(scope) = snapshot.language_scope_at(last_position)
1591 && let Some(scope_name) = scope.override_name()
1592 && settings
1593 .edit_predictions_disabled_in
1594 .iter()
1595 .any(|s| s == scope_name)
1596 {
1597 return false;
1598 }
1599 }
1600
1601 true
1602 }
1603
    /// Minimum delay between two prediction refreshes for the same throttle entity, as applied
    /// by `queue_prediction_refresh`.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// In tests the timeout is zero, so no throttling delay is applied.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1608
1609 fn queue_prediction_refresh(
1610 &mut self,
1611 project: Entity<Project>,
1612 throttle_entity: EntityId,
1613 cx: &mut Context<Self>,
1614 do_refresh: impl FnOnce(
1615 WeakEntity<Self>,
1616 &mut AsyncApp,
1617 )
1618 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1619 + 'static,
1620 ) {
1621 let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1622 let drop_on_cancel = is_ollama;
1623 let max_pending_predictions = if is_ollama { 1 } else { 2 };
1624 let project_state = self.get_or_init_project(&project, cx);
1625 let pending_prediction_id = project_state.next_pending_prediction_id;
1626 project_state.next_pending_prediction_id += 1;
1627 let last_request = project_state.last_prediction_refresh;
1628
1629 let task = cx.spawn(async move |this, cx| {
1630 if let Some((last_entity, last_timestamp)) = last_request
1631 && throttle_entity == last_entity
1632 && let Some(timeout) =
1633 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1634 {
1635 cx.background_executor().timer(timeout).await;
1636 }
1637
1638 // If this task was cancelled before the throttle timeout expired,
1639 // do not perform a request.
1640 let mut is_cancelled = true;
1641 this.update(cx, |this, cx| {
1642 let project_state = this.get_or_init_project(&project, cx);
1643 if !project_state
1644 .cancelled_predictions
1645 .remove(&pending_prediction_id)
1646 {
1647 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1648 is_cancelled = false;
1649 }
1650 })
1651 .ok();
1652 if is_cancelled {
1653 return None;
1654 }
1655
1656 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1657 let new_prediction_id = new_prediction_result
1658 .as_ref()
1659 .map(|(prediction, _)| prediction.id.clone());
1660
1661 // When a prediction completes, remove it from the pending list, and cancel
1662 // any pending predictions that were enqueued before it.
1663 this.update(cx, |this, cx| {
1664 let project_state = this.get_or_init_project(&project, cx);
1665
1666 let is_cancelled = project_state
1667 .cancelled_predictions
1668 .remove(&pending_prediction_id);
1669
1670 let new_current_prediction = if !is_cancelled
1671 && let Some((prediction_result, requested_by)) = new_prediction_result
1672 {
1673 match prediction_result.prediction {
1674 Ok(prediction) => {
1675 let new_prediction = CurrentEditPrediction {
1676 requested_by,
1677 prediction,
1678 was_shown: false,
1679 shown_with: None,
1680 };
1681
1682 if let Some(current_prediction) =
1683 project_state.current_prediction.as_ref()
1684 {
1685 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1686 {
1687 this.reject_current_prediction(
1688 EditPredictionRejectReason::Replaced,
1689 &project,
1690 cx,
1691 );
1692
1693 Some(new_prediction)
1694 } else {
1695 this.reject_prediction(
1696 new_prediction.prediction.id,
1697 EditPredictionRejectReason::CurrentPreferred,
1698 false,
1699 cx,
1700 );
1701 None
1702 }
1703 } else {
1704 Some(new_prediction)
1705 }
1706 }
1707 Err(reject_reason) => {
1708 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1709 None
1710 }
1711 }
1712 } else {
1713 None
1714 };
1715
1716 let project_state = this.get_or_init_project(&project, cx);
1717
1718 if let Some(new_prediction) = new_current_prediction {
1719 project_state.current_prediction = Some(new_prediction);
1720 }
1721
1722 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1723 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1724 if pending_prediction.id == pending_prediction_id {
1725 pending_predictions.remove(ix);
1726 for pending_prediction in pending_predictions.drain(0..ix) {
1727 project_state.cancel_pending_prediction(pending_prediction, cx)
1728 }
1729 break;
1730 }
1731 }
1732 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1733 cx.notify();
1734 })
1735 .ok();
1736
1737 new_prediction_id
1738 });
1739
1740 if project_state.pending_predictions.len() < max_pending_predictions {
1741 project_state.pending_predictions.push(PendingPrediction {
1742 id: pending_prediction_id,
1743 task,
1744 drop_on_cancel,
1745 });
1746 } else {
1747 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1748 project_state.pending_predictions.push(PendingPrediction {
1749 id: pending_prediction_id,
1750 task,
1751 drop_on_cancel,
1752 });
1753 project_state.cancel_pending_prediction(pending_prediction, cx);
1754 }
1755 }
1756
1757 pub fn request_prediction(
1758 &mut self,
1759 project: &Entity<Project>,
1760 active_buffer: &Entity<Buffer>,
1761 position: language::Anchor,
1762 trigger: PredictEditsRequestTrigger,
1763 cx: &mut Context<Self>,
1764 ) -> Task<Result<Option<EditPredictionResult>>> {
1765 self.request_prediction_internal(
1766 project.clone(),
1767 active_buffer.clone(),
1768 position,
1769 trigger,
1770 cx.has_flag::<Zeta2FeatureFlag>(),
1771 cx,
1772 )
1773 }
1774
    /// Builds an `EditPredictionModelInput` for `position` in `active_buffer` and dispatches it
    /// to the configured model.
    ///
    /// The input contains the following:
    /// - the project's stored events, related files and recent paths;
    /// - the recorded user actions, plus a local-only `CursorMovement` record when the cursor
    ///   has moved since the last recorded action in this buffer;
    /// - a diagnostic search range of `DIAGNOSTIC_LINES_RANGE` rows above and below the cursor.
    ///
    /// If the model returns no prediction, `allow_jump` is set, and there are stored events,
    /// the request is made once more (with `allow_jump = false`) at the location returned by
    /// `next_diagnostic_location`. Otherwise an empty result resolves to `Ok(None)`.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        // Initialize the project state, then re-borrow it immutably; the `unwrap` cannot fail
        // because the state was just created if missing.
        self.get_or_init_project(&project, cx);
        let project_state = self.projects.get(&project.entity_id()).unwrap();
        let stored_events = project_state.events(cx);
        let has_events = !stored_events.is_empty();
        let events: Vec<Arc<zeta_prompt::Event>> =
            stored_events.into_iter().map(|e| e.event).collect();
        let debug_tx = project_state.debug_tx.clone();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let current_offset = position.to_offset(&snapshot);

        let mut user_actions: Vec<UserActionRecord> =
            project_state.user_actions.iter().cloned().collect();

        // Append a cursor-movement record (to this local copy only) when the cursor is no
        // longer where the last recorded action in this buffer left it.
        if let Some(last_action) = user_actions.last() {
            if last_action.buffer_id == active_buffer.entity_id()
                && current_offset != last_action.offset
            {
                let timestamp_epoch_ms = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .map(|d| d.as_millis() as u64)
                    .unwrap_or(0);
                user_actions.push(UserActionRecord {
                    action_type: UserActionType::CursorMovement,
                    buffer_id: active_buffer.entity_id(),
                    line_number: cursor_point.row,
                    offset: current_offset,
                    timestamp_epoch_ms,
                });
            }
        }
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let related_files = self.context_for_project(&project, cx);

        let inputs = EditPredictionModelInput {
            project: project.clone(),
            buffer: active_buffer.clone(),
            snapshot: snapshot.clone(),
            position,
            events,
            related_files,
            recent_paths: project_state.recent_paths.clone(),
            trigger,
            diagnostic_search_range: diagnostic_search_range.clone(),
            debug_tx,
            user_actions,
        };

        let task = match &self.edit_prediction_model {
            EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
                self,
                inputs,
                Some(zeta_prompt::EditPredictionModelKind::Zeta1),
                cx,
            ),
            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
                self,
                inputs,
                Some(zeta_prompt::EditPredictionModelKind::Zeta2),
                cx,
            ),
            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
            EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Jump to a diagnostic outside the searched range and retry there, with
                // further jumps disabled.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1893
    /// Finds a diagnostic location to jump to from the cursor in `active_buffer`.
    ///
    /// The method first looks in the active buffer for the primary entry of a diagnostic group
    /// that does not overlap `active_buffer_diagnostic_search_range`. It picks the one whose
    /// start row is closest to the cursor row. If there is none, it takes the path from
    /// `project.diagnostic_summaries(false, cx)` (excluding the active buffer's own path) that
    /// shares the most leading path components with the active buffer's path. It then opens
    /// that buffer and returns the start of its diagnostic with the smallest severity value.
    /// NOTE(review): this is presumably the most severe one under LSP numbering; confirm.
    ///
    /// Returns `Ok(None)` when no candidate is found. Errors from opening the buffer are
    /// propagated.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // Find the primary diagnostic closest to the cursor that lies outside the search range;
        // diagnostics inside that range were already near enough to be covered by the request.
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            });

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // Prefer the path that shares the most leading components with the
                        // active buffer's path (none are shared if it has no file).
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            });

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1972
1973 async fn send_raw_llm_request(
1974 request: RawCompletionRequest,
1975 client: Arc<Client>,
1976 custom_url: Option<Arc<Url>>,
1977 llm_token: LlmApiToken,
1978 app_version: Version,
1979 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1980 let url = if let Some(custom_url) = custom_url {
1981 custom_url.as_ref().clone()
1982 } else {
1983 client
1984 .http_client()
1985 .build_zed_llm_url("/predict_edits/raw", &[])?
1986 };
1987
1988 Self::send_api_request(
1989 |builder| {
1990 let req = builder
1991 .uri(url.as_ref())
1992 .body(serde_json::to_string(&request)?.into());
1993 Ok(req?)
1994 },
1995 client,
1996 llm_token,
1997 app_version,
1998 true,
1999 )
2000 .await
2001 }
2002
2003 pub(crate) async fn send_v3_request(
2004 input: ZetaPromptInput,
2005 client: Arc<Client>,
2006 llm_token: LlmApiToken,
2007 app_version: Version,
2008 trigger: PredictEditsRequestTrigger,
2009 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2010 let url = client
2011 .http_client()
2012 .build_zed_llm_url("/predict_edits/v3", &[])?;
2013
2014 let request = PredictEditsV3Request { input, trigger };
2015
2016 let json_bytes = serde_json::to_vec(&request)?;
2017 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2018
2019 Self::send_api_request(
2020 |builder| {
2021 let req = builder
2022 .uri(url.as_ref())
2023 .header("Content-Encoding", "zstd")
2024 .body(compressed.clone().into());
2025 Ok(req?)
2026 },
2027 client,
2028 llm_token,
2029 app_version,
2030 true,
2031 )
2032 .await
2033 }
2034
    /// Unwraps an API response and applies its side effects to the store behind `this`.
    ///
    /// On success, any returned usage is passed to the user store through
    /// `update_edit_prediction_usage`. Failures to update the weak entity are ignored. On a
    /// `ZedUpdateRequiredError`, `update_required` is set and an app notification with an
    /// "Update Zed" link is shown. The data or the original error is returned unchanged.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        // A unique id keyed on the error type, presumably so repeated failures
                        // reuse one notification; confirm against `show_app_notification`.
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    });
                }
                Err(err)
            }
        }
    }
2077
    /// POSTs a request built by `build` and deserializes a successful JSON response as `Res`.
    ///
    /// The bearer token is taken from the `ZED_PREDICT_EDITS_TOKEN` environment variable when
    /// it is set. Otherwise it is acquired from `llm_token`. If that fails, the call errors when
    /// `require_auth` is true; when it is false the request is sent without an `Authorization`
    /// header. `build` may be called twice: if a token was sent and the response says the LLM
    /// token needs refreshing, the token is refreshed and the request is retried once.
    ///
    /// Returns the deserialized body together with any `EditPredictionUsage` parsed from the
    /// response headers.
    ///
    /// # Errors
    /// - `ZedUpdateRequiredError` when the response's minimum-required-version header is newer
    ///   than `app_version`. This is checked before the status.
    /// - An error with the status and body for unsuccessful responses.
    /// - Token acquisition/refresh, transport, and deserialization errors.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        require_auth: bool,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();

        let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
            Some(custom_token)
        } else if require_auth {
            Some(llm_token.acquire(&client).await?)
        } else {
            llm_token.acquire(&client).await.ok()
        };
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let mut request_builder = request_builder
                .header("Content-Type", "application/json")
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

            // Only add Authorization header if we have a token
            if let Some(ref token_value) = token {
                request_builder =
                    request_builder.header("Authorization", format!("Bearer {}", token_value));
            }

            let request = build(request_builder)?;

            let mut response = http_client.send(request).await?;

            // An unparsable version header is ignored.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                // Retry at most once with a refreshed token. NOTE(review): this also replaces a
                // token taken from `ZED_PREDICT_EDITS_TOKEN`; confirm that is intended.
                did_retry = true;
                token = Some(llm_token.refresh(&client).await?);
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2149
2150 pub fn refresh_context(
2151 &mut self,
2152 project: &Entity<Project>,
2153 buffer: &Entity<language::Buffer>,
2154 cursor_position: language::Anchor,
2155 cx: &mut Context<Self>,
2156 ) {
2157 self.get_or_init_project(project, cx)
2158 .context
2159 .update(cx, |store, cx| {
2160 store.refresh(buffer.clone(), cursor_position, cx);
2161 });
2162 }
2163
2164 #[cfg(feature = "cli-support")]
2165 pub fn set_context_for_buffer(
2166 &mut self,
2167 project: &Entity<Project>,
2168 related_files: Vec<RelatedFile>,
2169 cx: &mut Context<Self>,
2170 ) {
2171 self.get_or_init_project(project, cx)
2172 .context
2173 .update(cx, |store, cx| {
2174 store.set_related_files(related_files, cx);
2175 });
2176 }
2177
2178 #[cfg(feature = "cli-support")]
2179 pub fn set_recent_paths_for_project(
2180 &mut self,
2181 project: &Entity<Project>,
2182 paths: impl IntoIterator<Item = project::ProjectPath>,
2183 cx: &mut Context<Self>,
2184 ) {
2185 let project_state = self.get_or_init_project(project, cx);
2186 project_state.recent_paths = paths.into_iter().collect();
2187 }
2188
2189 fn is_file_open_source(
2190 &self,
2191 project: &Entity<Project>,
2192 file: &Arc<dyn File>,
2193 cx: &App,
2194 ) -> bool {
2195 if !file.is_local() || file.is_private() {
2196 return false;
2197 }
2198 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2199 return false;
2200 };
2201 project_state
2202 .license_detection_watchers
2203 .get(&file.worktree_id(cx))
2204 .as_ref()
2205 .is_some_and(|watcher| watcher.is_project_open_source())
2206 }
2207
2208 fn load_data_collection_choice() -> DataCollectionChoice {
2209 let choice = KEY_VALUE_STORE
2210 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2211 .log_err()
2212 .flatten();
2213
2214 match choice.as_deref() {
2215 Some("true") => DataCollectionChoice::Enabled,
2216 Some("false") => DataCollectionChoice::Disabled,
2217 Some(_) => {
2218 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2219 DataCollectionChoice::NotAnswered
2220 }
2221 None => DataCollectionChoice::NotAnswered,
2222 }
2223 }
2224
2225 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2226 self.data_collection_choice = self.data_collection_choice.toggle();
2227 let new_choice = self.data_collection_choice;
2228 let is_enabled = new_choice.is_enabled(cx);
2229 db::write_and_log(cx, move || {
2230 KEY_VALUE_STORE.write_kvp(
2231 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2232 is_enabled.to_string(),
2233 )
2234 });
2235 }
2236
2237 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2238 self.shown_predictions.iter()
2239 }
2240
2241 pub fn shown_completions_len(&self) -> usize {
2242 self.shown_predictions.len()
2243 }
2244
2245 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2246 self.rated_predictions.contains(id)
2247 }
2248
2249 pub fn rate_prediction(
2250 &mut self,
2251 prediction: &EditPrediction,
2252 rating: EditPredictionRating,
2253 feedback: String,
2254 cx: &mut Context<Self>,
2255 ) {
2256 self.rated_predictions.insert(prediction.id.clone());
2257 telemetry::event!(
2258 "Edit Prediction Rated",
2259 request_id = prediction.id.to_string(),
2260 rating,
2261 inputs = prediction.inputs,
2262 output = prediction
2263 .edit_preview
2264 .as_unified_diff(prediction.snapshot.file(), &prediction.edits),
2265 feedback
2266 );
2267 self.client.telemetry().flush_events().detach();
2268 cx.notify();
2269 }
2270}
2271
/// Collapses the trailing run of mergeable events in `events` into a single
/// `StoredEvent`.
///
/// Starting from the newest event, the run extends backwards while each event
/// `can_merge` with the event that follows it. The check uses
/// `latest_snapshot` and `latest_edit_range`. If the run has fewer than two
/// events, nothing changes.
///
/// Otherwise the diff from the oldest run member's `old_snapshot` to
/// `end_snapshot` is recomputed. When `compute_diff_between_snapshots` returns
/// `Some`, the run is replaced by one `BufferChange` event carrying that diff.
/// When it returns `None`, the events are left untouched.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Count how many events at the back of the queue form a mergeable chain.
    // `next_old_event` holds the newer neighbour of the event being inspected.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Zero or one trailing event: nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // `mergeable_count >= 2` here, so the range is non-empty and `unwrap` cannot fail.
    // `peek` does not consume, so the oldest event is still visited by `all` below.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        // Paths and the open-source flag are taken from the oldest event.
        // NOTE(review): the match has a single arm, so `zeta_prompt::Event`
        // presumably has only the `BufferChange` variant; confirm if variants are added.
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change is marked predicted only if every
                    // merged event was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                // Anchor the merged edit range in the end snapshot.
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Drop the merged run and append its single replacement at the back.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2332
2333pub(crate) fn filter_redundant_excerpts(
2334 mut related_files: Vec<RelatedFile>,
2335 cursor_path: &Path,
2336 cursor_row_range: Range<u32>,
2337) -> Vec<RelatedFile> {
2338 for file in &mut related_files {
2339 if file.path.as_ref() == cursor_path {
2340 file.excerpts.retain(|excerpt| {
2341 excerpt.row_range.start < cursor_row_range.start
2342 || excerpt.row_range.end > cursor_row_range.end
2343 });
2344 }
2345 }
2346 related_files.retain(|file| !file.excerpts.is_empty());
2347 related_files
2348}
2349
2350#[derive(Error, Debug)]
2351#[error(
2352 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2353)]
2354pub struct ZedUpdateRequiredError {
2355 minimum_version: Version,
2356}
2357
/// The user's answer to the edit-prediction data-collection prompt, as
/// persisted in the key-value store.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No answer has been stored (or the stored value was unrecognized).
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2364
2365impl DataCollectionChoice {
2366 pub fn is_enabled(self, cx: &App) -> bool {
2367 if cx.is_staff() {
2368 return true;
2369 }
2370 match self {
2371 Self::Enabled => true,
2372 Self::NotAnswered | Self::Disabled => false,
2373 }
2374 }
2375
2376 #[must_use]
2377 pub fn toggle(&self) -> DataCollectionChoice {
2378 match self {
2379 Self::Enabled => Self::Disabled,
2380 Self::Disabled => Self::Enabled,
2381 Self::NotAnswered => Self::Enabled,
2382 }
2383 }
2384}
2385
2386impl From<bool> for DataCollectionChoice {
2387 fn from(value: bool) -> Self {
2388 match value {
2389 true => DataCollectionChoice::Enabled,
2390 false => DataCollectionChoice::Disabled,
2391 }
2392 }
2393}
2394
2395struct ZedPredictUpsell;
2396
2397impl Dismissable for ZedPredictUpsell {
2398 const KEY: &'static str = "dismissed-edit-predict-upsell";
2399
2400 fn dismissed() -> bool {
2401 // To make this backwards compatible with older versions of Zed, we
2402 // check if the user has seen the previous Edit Prediction Onboarding
2403 // before, by checking the data collection choice which was written to
2404 // the database once the user clicked on "Accept and Enable"
2405 if KEY_VALUE_STORE
2406 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2407 .log_err()
2408 .is_some_and(|s| s.is_some())
2409 {
2410 return true;
2411 }
2412
2413 KEY_VALUE_STORE
2414 .read_kvp(Self::KEY)
2415 .log_err()
2416 .is_some_and(|s| s.is_some())
2417 }
2418}
2419
2420pub fn should_show_upsell_modal() -> bool {
2421 !ZedPredictUpsell::dismissed()
2422}
2423
2424pub fn init(cx: &mut App) {
2425 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2426 workspace.register_action(
2427 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2428 ZedPredictModal::toggle(
2429 workspace,
2430 workspace.user_store().clone(),
2431 workspace.client().clone(),
2432 window,
2433 cx,
2434 )
2435 },
2436 );
2437
2438 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2439 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2440 settings
2441 .project
2442 .all_languages
2443 .edit_predictions
2444 .get_or_insert_default()
2445 .provider = Some(EditPredictionProvider::None)
2446 });
2447 });
2448 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2449 EditPredictionStore::try_global(cx).and_then(|store| {
2450 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2451 })
2452 }
2453
2454 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2455 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2456 copilot_ui::initiate_sign_in(copilot, window, cx);
2457 }
2458 });
2459 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2460 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2461 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2462 }
2463 });
2464 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2465 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2466 copilot_ui::initiate_sign_out(copilot, window, cx);
2467 }
2468 });
2469 })
2470 .detach();
2471}