1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::SubmitEditPredictionFeedbackBody;
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{EditPredictionProvider, Settings as _, update_settings_file};
39use std::collections::{VecDeque, hash_map};
40use std::env;
41use text::Edit;
42use workspace::Workspace;
43use zeta_prompt::{ZetaFormat, ZetaPromptInput};
44
45use std::mem;
46use std::ops::Range;
47use std::path::Path;
48use std::rc::Rc;
49use std::str::FromStr as _;
50use std::sync::Arc;
51use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
52use thiserror::Error;
53use util::{RangeExt as _, ResultExt as _};
54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
55
56pub mod cursor_excerpt;
57pub mod example_spec;
58mod license_detection;
59pub mod mercury;
60pub mod ollama;
61mod onboarding_modal;
62pub mod open_ai_response;
63mod prediction;
64pub mod sweep_ai;
65
66pub mod udiff;
67
68mod capture_example;
69mod zed_edit_prediction_delegate;
70pub mod zeta1;
71pub mod zeta2;
72
73#[cfg(test)]
74mod edit_prediction_tests;
75
76use crate::license_detection::LicenseDetectionWatcher;
77use crate::mercury::Mercury;
78use crate::ollama::Ollama;
79use crate::onboarding_modal::ZedPredictModal;
80pub use crate::prediction::EditPrediction;
81pub use crate::prediction::EditPredictionId;
82use crate::prediction::EditPredictionResult;
83pub use crate::sweep_ai::SweepAi;
84pub use capture_example::capture_example;
85pub use language_model::ApiKeyState;
86pub use telemetry_events::EditPredictionRating;
87pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
88
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);

/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Row distance used by `StoredEvent::can_merge` (compared against `lines_between_ranges`)
/// to decide whether edits are close enough to each other / to the latest edit.
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Time window used when grouping consecutive changes.
// NOTE(review): not referenced in the visible part of this file; presumably consulted when
// coalescing edits into `LastEvent` — confirm at the use site.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
// NOTE(review): presumably the key-value-store key under which the user's data-collection
// choice is persisted (see `load_data_collection_choice`, not visible here).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
// NOTE(review): presumably the debounce interval used by `handle_rejected_predictions`
// (not visible here) before sending a batch of rejections — confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
105
/// Feature flag named `zeta2`; among other things it gates refreshing predictions on
/// diagnostics updates (see `EditPredictionStore::handle_project_event`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Always on for staff.
    fn enabled_for_staff() -> bool {
        true
    }
}

/// App-global handle to the singleton [`EditPredictionStore`], installed by
/// [`EditPredictionStore::global`].
#[derive(Clone)]
struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);

impl Global for EditPredictionStoreGlobal {}

/// Configuration for using the raw Zeta2 endpoint.
/// When set, the client uses the raw endpoint and constructs the prompt itself.
/// The version is also used as the Baseten environment name (lowercased).
#[derive(Clone)]
pub struct Zeta2RawConfig {
    /// Optional model identifier (read from `ZED_ZETA_MODEL` by
    /// `EditPredictionStore::zeta2_raw_config_from_env`).
    pub model_id: Option<String>,
    /// Prompt format (parsed from `ZED_ZETA_FORMAT` by
    /// `EditPredictionStore::zeta2_raw_config_from_env`).
    pub format: ZetaFormat,
}
129
/// Central store for edit predictions: holds per-project state, the LLM API token, the
/// provider-specific backends, and the channel used to report rejected predictions.
pub struct EditPredictionStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed whenever the global `RefreshLlmTokenListener` emits (see `new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ProjectState>,
    // NOTE(review): initialized to `false` in `new`; presumably set when the server reports
    // that a newer client version is required — confirm at the write site.
    update_required: bool,
    // Backend used to produce predictions; `new` selects `EditPredictionModel::Zeta2`.
    edit_prediction_model: EditPredictionModel,
    // Optional raw-endpoint override for Zeta2, read from the environment in `new`.
    zeta2_raw_config: Option<Zeta2RawConfig>,
    pub sweep_ai: SweepAi,
    pub mercury: Mercury,
    pub ollama: Ollama,
    data_collection_choice: DataCollectionChoice,
    // Feeds the background task spawned in `new` (`handle_rejected_predictions`).
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // NOTE(review): presumably recently shown predictions kept for rating — confirm.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}

/// Backend that produces edit predictions.
// The derived `Default` is `Zeta1`, whereas `EditPredictionStore::new` explicitly selects
// `Zeta2`.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
    Mercury,
    Ollama,
}

/// Inputs handed to a prediction backend for a single request.
#[derive(Clone)]
pub struct EditPredictionModelInput {
    project: Entity<Project>,
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    // NOTE(review): presumably the cursor position the prediction is requested at.
    position: Anchor,
    events: Vec<Arc<zeta_prompt::Event>>,
    related_files: Vec<RelatedFile>,
    recent_paths: VecDeque<ProjectPath>,
    trigger: PredictEditsRequestTrigger,
    diagnostic_search_range: Range<Point>,
    // Present only while a debug listener is attached (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    pub user_actions: Vec<UserActionRecord>,
}
172
/// Diagnostic events streamed to the receiver returned by
/// `EditPredictionStore::debug_info`.
#[derive(Debug)]
pub enum DebugEvent {
    ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
    ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
    EditPredictionStarted(EditPredictionStartedDebugEvent),
    EditPredictionFinished(EditPredictionFinishedDebugEvent),
}

/// Emitted when a project's related-excerpt store starts refreshing.
#[derive(Debug)]
pub struct ContextRetrievalStartedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // `handle_excerpt_store_event` currently sends an empty string here.
    pub search_prompt: String,
}

/// Emitted when a project's related-excerpt store finishes refreshing.
#[derive(Debug)]
pub struct ContextRetrievalFinishedDebugEvent {
    pub project_entity_id: EntityId,
    pub timestamp: Instant,
    // Label/value pairs for display (cache hits and LSP timings).
    pub metadata: Vec<(&'static str, SharedString)>,
}

/// Emitted when a prediction request starts.
#[derive(Debug)]
pub struct EditPredictionStartedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    // NOTE(review): presumably `None` for backends that don't build a text prompt.
    pub prompt: Option<String>,
}

/// Emitted when a prediction request finishes.
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
    pub buffer: WeakEntity<Buffer>,
    pub position: Anchor,
    pub model_output: Option<String>,
}

/// Maximum number of user actions retained per project (enforced by
/// `ProjectState::record_user_action`).
const USER_ACTION_HISTORY_SIZE: usize = 16;

/// A single recorded user action in a buffer.
#[derive(Clone, Debug)]
pub struct UserActionRecord {
    pub action_type: UserActionType,
    pub buffer_id: EntityId,
    pub line_number: u32,
    pub offset: usize,
    // NOTE(review): presumably milliseconds since the Unix epoch, per the name.
    pub timestamp_epoch_ms: u64,
}

/// Kind of user action recorded in the history.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
228
/// An event with associated metadata for reconstructing buffer state.
#[derive(Clone)]
pub struct StoredEvent {
    pub event: Arc<zeta_prompt::Event>,
    /// Buffer snapshot taken before this event's edits.
    pub old_snapshot: TextBufferSnapshot,
    /// Edited region, anchored in the snapshot taken after the edits.
    pub edit_range: Range<Anchor>,
}
236
237impl StoredEvent {
238 fn can_merge(
239 &self,
240 next_old_event: &&&StoredEvent,
241 new_snapshot: &TextBufferSnapshot,
242 last_edit_range: &Range<Anchor>,
243 ) -> bool {
244 // Events must be for the same buffer
245 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
246 return false;
247 }
248 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
249 return false;
250 }
251
252 let a_is_predicted = matches!(
253 self.event.as_ref(),
254 zeta_prompt::Event::BufferChange {
255 predicted: true,
256 ..
257 }
258 );
259 let b_is_predicted = matches!(
260 next_old_event.event.as_ref(),
261 zeta_prompt::Event::BufferChange {
262 predicted: true,
263 ..
264 }
265 );
266
267 // If events come from the same source (both predicted or both manual) then
268 // we would have coalesced them already.
269 if a_is_predicted == b_is_predicted {
270 return false;
271 }
272
273 let left_range = self.edit_range.to_point(new_snapshot);
274 let right_range = next_old_event.edit_range.to_point(new_snapshot);
275 let latest_range = last_edit_range.to_point(&new_snapshot);
276
277 // Events near to the latest edit are not merged if their sources differ.
278 if lines_between_ranges(&left_range, &latest_range)
279 .min(lines_between_ranges(&right_range, &latest_range))
280 <= CHANGE_GROUPING_LINE_SPAN
281 {
282 return false;
283 }
284
285 // Events that are distant from each other are not merged.
286 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
287 return false;
288 }
289
290 true
291 }
292}
293
294fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
295 if left.start > right.end {
296 return left.start.row - right.end.row;
297 }
298 if right.start > left.end {
299 return right.start.row - left.end.row;
300 }
301 0
302}
303
/// Per-project bookkeeping for edit predictions.
struct ProjectState {
    // Finalized edit events, oldest first.
    events: VecDeque<StoredEvent>,
    // Most recent edit event, kept apart from `events`; `events()` finalizes it on demand.
    last_event: Option<LastEvent>,
    // Recently active paths, most recent first (maintained by `handle_project_event`).
    recent_paths: VecDeque<ProjectPath>,
    // Tracked buffers, keyed by buffer entity id; entries are removed on buffer release.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // Fixed capacity of two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Set by `debug_info`; when present, debug events are sent here.
    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled (see `cancel_pending_prediction`).
    cancelled_predictions: HashSet<usize>,
    context: Entity<RelatedExcerptStore>,
    // One license watcher per worktree; removed when the worktree is released.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    // Bounded to `USER_ACTION_HISTORY_SIZE` entries by `record_user_action`.
    user_actions: VecDeque<UserActionRecord>,
    // Project-event subscription and project-release observer (see `get_or_init_project`).
    _subscriptions: [gpui::Subscription; 2],
    copilot: Option<Entity<Copilot>>,
}
321
322impl ProjectState {
323 fn record_user_action(&mut self, action: UserActionRecord) {
324 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
325 self.user_actions.pop_front();
326 }
327 self.user_actions.push_back(action);
328 }
329
330 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
331 self.events
332 .iter()
333 .cloned()
334 .chain(self.last_event.as_ref().iter().flat_map(|event| {
335 let (one, two) = event.split_by_pause();
336 let one = one.finalize(&self.license_detection_watchers, cx);
337 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
338 one.into_iter().chain(two)
339 }))
340 .collect()
341 }
342
343 fn cancel_pending_prediction(
344 &mut self,
345 pending_prediction: PendingPrediction,
346 cx: &mut Context<EditPredictionStore>,
347 ) {
348 self.cancelled_predictions.insert(pending_prediction.id);
349
350 if pending_prediction.drop_on_cancel {
351 drop(pending_prediction.task);
352 } else {
353 cx.spawn(async move |this, cx| {
354 let Some(prediction_id) = pending_prediction.task.await else {
355 return;
356 };
357
358 this.update(cx, |this, cx| {
359 this.reject_prediction(
360 prediction_id,
361 EditPredictionRejectReason::Canceled,
362 false,
363 cx,
364 );
365 })
366 .ok();
367 })
368 .detach()
369 }
370 }
371
372 fn active_buffer(
373 &self,
374 project: &Entity<Project>,
375 cx: &App,
376 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
377 let project = project.read(cx);
378 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
379 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
380 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
381 Some((active_buffer, registered_buffer.last_position))
382 }
383}
384
/// The prediction currently offered for a project, plus display bookkeeping.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // NOTE(review): presumably set once the prediction has been displayed — confirm.
    pub was_shown: bool,
    pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
}
392
393impl CurrentEditPrediction {
394 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
395 let Some(new_edits) = self
396 .prediction
397 .interpolate(&self.prediction.buffer.read(cx))
398 else {
399 return false;
400 };
401
402 if self.prediction.buffer != old_prediction.prediction.buffer {
403 return true;
404 }
405
406 let Some(old_edits) = old_prediction
407 .prediction
408 .interpolate(&old_prediction.prediction.buffer.read(cx))
409 else {
410 return true;
411 };
412
413 let requested_by_buffer_id = self.requested_by.buffer_id();
414
415 // This reduces the occurrence of UI thrash from replacing edits
416 //
417 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
418 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
419 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
420 && old_edits.len() == 1
421 && new_edits.len() == 1
422 {
423 let (old_range, old_text) = &old_edits[0];
424 let (new_range, new_text) = &new_edits[0];
425 new_range == old_range && new_text.starts_with(old_text.as_ref())
426 } else {
427 true
428 }
429 }
430}
431
432#[derive(Debug, Clone)]
433enum PredictionRequestedBy {
434 DiagnosticsUpdate,
435 Buffer(EntityId),
436}
437
438impl PredictionRequestedBy {
439 pub fn buffer_id(&self) -> Option<EntityId> {
440 match self {
441 PredictionRequestedBy::DiagnosticsUpdate => None,
442 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
443 }
444 }
445}
446
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    // Recorded in `ProjectState::cancelled_predictions` when cancelled.
    id: usize,
    // Resolves to the produced prediction's id, if any.
    task: Task<Option<EditPredictionId>>,
    /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
    /// If false, the task is awaited to completion so rejection can be reported.
    drop_on_cancel: bool,
}
455
456/// A prediction from the perspective of a buffer.
457#[derive(Debug)]
458enum BufferEditPrediction<'a> {
459 Local { prediction: &'a EditPrediction },
460 Jump { prediction: &'a EditPrediction },
461}
462
463#[cfg(test)]
464impl std::ops::Deref for BufferEditPrediction<'_> {
465 type Target = EditPrediction;
466
467 fn deref(&self) -> &Self::Target {
468 match self {
469 BufferEditPrediction::Local { prediction } => prediction,
470 BufferEditPrediction::Jump { prediction } => prediction,
471 }
472 }
473}
474
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    // File as of the last recorded change (replaced in `report_changes_for_buffer`).
    file: Option<Arc<dyn File>>,
    // Snapshot as of the last recorded change; newer edits are computed relative to it.
    snapshot: TextBufferSnapshot,
    // NOTE(review): presumably the last known cursor position in this buffer — confirm at
    // the write site.
    last_position: Option<Anchor>,
    // Edit subscription and release observer (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent edit event, kept apart from the finalized `ProjectState::events`;
/// `finalize` turns it into a `StoredEvent`.
#[derive(Clone)]
struct LastEvent {
    old_snapshot: TextBufferSnapshot,
    new_snapshot: TextBufferSnapshot,
    old_file: Option<Arc<dyn File>>,
    new_file: Option<Arc<dyn File>>,
    edit_range: Option<Range<Anchor>>,
    // Copied into `zeta_prompt::Event::BufferChange::predicted` by `finalize`; presumably
    // true for edits applied from a prediction rather than typed manually.
    predicted: bool,
    // When set, `split_by_pause` splits the event at this snapshot.
    snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
    last_edit_time: Option<Instant>,
}
493
494impl LastEvent {
495 pub fn finalize(
496 &self,
497 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
498 cx: &App,
499 ) -> Option<StoredEvent> {
500 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
501 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
502
503 let in_open_source_repo =
504 [self.new_file.as_ref(), self.old_file.as_ref()]
505 .iter()
506 .all(|file| {
507 file.is_some_and(|file| {
508 license_detection_watchers
509 .get(&file.worktree_id(cx))
510 .is_some_and(|watcher| watcher.is_project_open_source())
511 })
512 });
513
514 let (diff, edit_range) =
515 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
516
517 if path == old_path && diff.is_empty() {
518 None
519 } else {
520 Some(StoredEvent {
521 event: Arc::new(zeta_prompt::Event::BufferChange {
522 old_path,
523 path,
524 diff,
525 in_open_source_repo,
526 predicted: self.predicted,
527 }),
528 edit_range: self.new_snapshot.anchor_before(edit_range.start)
529 ..self.new_snapshot.anchor_before(edit_range.end),
530 old_snapshot: self.old_snapshot.clone(),
531 })
532 }
533 }
534
535 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
536 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
537 return (self.clone(), None);
538 };
539
540 let before = LastEvent {
541 old_snapshot: self.old_snapshot.clone(),
542 new_snapshot: boundary_snapshot.clone(),
543 old_file: self.old_file.clone(),
544 new_file: self.new_file.clone(),
545 edit_range: None,
546 predicted: self.predicted,
547 snapshot_after_last_editing_pause: None,
548 last_edit_time: self.last_edit_time,
549 };
550
551 let after = LastEvent {
552 old_snapshot: boundary_snapshot.clone(),
553 new_snapshot: self.new_snapshot.clone(),
554 old_file: self.old_file.clone(),
555 new_file: self.new_file.clone(),
556 edit_range: None,
557 predicted: self.predicted,
558 snapshot_after_last_editing_pause: None,
559 last_edit_time: self.last_edit_time,
560 };
561
562 (before, Some(after))
563 }
564}
565
566pub(crate) fn compute_diff_between_snapshots(
567 old_snapshot: &TextBufferSnapshot,
568 new_snapshot: &TextBufferSnapshot,
569) -> Option<(String, Range<Point>)> {
570 let edits: Vec<Edit<usize>> = new_snapshot
571 .edits_since::<usize>(&old_snapshot.version)
572 .collect();
573
574 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
575
576 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
577 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
578 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
579 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
580
581 const CONTEXT_LINES: u32 = 3;
582
583 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
584 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
585 let old_context_end_row =
586 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
587 let new_context_end_row =
588 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
589
590 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
591 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
592 let old_end_line_offset = old_snapshot
593 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
594 let new_end_line_offset = new_snapshot
595 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
596 let old_edit_range = old_start_line_offset..old_end_line_offset;
597 let new_edit_range = new_start_line_offset..new_end_line_offset;
598
599 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
600 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
601
602 let diff = language::unified_diff_with_offsets(
603 &old_region_text,
604 &new_region_text,
605 old_context_start_row,
606 new_context_start_row,
607 );
608
609 Some((diff, new_start_point..new_end_point))
610}
611
612fn buffer_path_with_id_fallback(
613 file: Option<&Arc<dyn File>>,
614 snapshot: &TextBufferSnapshot,
615 cx: &App,
616) -> Arc<Path> {
617 if let Some(file) = file {
618 file.full_path(cx).into()
619 } else {
620 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
621 }
622}
623
624impl EditPredictionStore {
625 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
626 cx.try_global::<EditPredictionStoreGlobal>()
627 .map(|global| global.0.clone())
628 }
629
630 pub fn global(
631 client: &Arc<Client>,
632 user_store: &Entity<UserStore>,
633 cx: &mut App,
634 ) -> Entity<Self> {
635 cx.try_global::<EditPredictionStoreGlobal>()
636 .map(|global| global.0.clone())
637 .unwrap_or_else(|| {
638 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
639 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
640 ep_store
641 })
642 }
643
644 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
645 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
646 let data_collection_choice = Self::load_data_collection_choice();
647
648 let llm_token = LlmApiToken::default();
649
650 let (reject_tx, reject_rx) = mpsc::unbounded();
651 cx.background_spawn({
652 let client = client.clone();
653 let llm_token = llm_token.clone();
654 let app_version = AppVersion::global(cx);
655 let background_executor = cx.background_executor().clone();
656 async move {
657 Self::handle_rejected_predictions(
658 reject_rx,
659 client,
660 llm_token,
661 app_version,
662 background_executor,
663 )
664 .await
665 }
666 })
667 .detach();
668
669 let this = Self {
670 projects: HashMap::default(),
671 client,
672 user_store,
673 llm_token,
674 _llm_token_subscription: cx.subscribe(
675 &refresh_llm_token_listener,
676 |this, _listener, _event, cx| {
677 let client = this.client.clone();
678 let llm_token = this.llm_token.clone();
679 cx.spawn(async move |_this, _cx| {
680 llm_token.refresh(&client).await?;
681 anyhow::Ok(())
682 })
683 .detach_and_log_err(cx);
684 },
685 ),
686 update_required: false,
687 edit_prediction_model: EditPredictionModel::Zeta2,
688 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
689 sweep_ai: SweepAi::new(cx),
690 mercury: Mercury::new(cx),
691 ollama: Ollama::new(),
692
693 data_collection_choice,
694 reject_predictions_tx: reject_tx,
695 rated_predictions: Default::default(),
696 shown_predictions: Default::default(),
697 };
698
699 this
700 }
701
702 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
703 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
704 let format = ZetaFormat::parse(&version_str).ok()?;
705 let model_id = env::var("ZED_ZETA_MODEL").ok();
706 Some(Zeta2RawConfig { model_id, format })
707 }
708
    /// Selects the backend used to produce predictions.
    pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Enables the raw Zeta2 endpoint with `config`, replacing any previous config.
    pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
        self.zeta2_raw_config = Some(config);
    }

    /// Returns the raw Zeta2 config, if one is set.
    pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
        self.zeta2_raw_config.as_ref()
    }
720
721 pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
722 use ui::IconName;
723 match self.edit_prediction_model {
724 EditPredictionModel::Sweep => {
725 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
726 .with_disabled(IconName::SweepAiDisabled)
727 .with_up(IconName::SweepAiUp)
728 .with_down(IconName::SweepAiDown)
729 .with_error(IconName::SweepAiError)
730 }
731 EditPredictionModel::Mercury => {
732 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
733 }
734 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
735 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
736 .with_disabled(IconName::ZedPredictDisabled)
737 .with_up(IconName::ZedPredictUp)
738 .with_down(IconName::ZedPredictDown)
739 .with_error(IconName::ZedPredictError)
740 }
741 EditPredictionModel::Ollama => {
742 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
743 }
744 }
745 }
746
747 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
748 self.sweep_ai.api_token.read(cx).has_key()
749 }
750
751 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
752 self.mercury.api_token.read(cx).has_key()
753 }
754
755 pub fn clear_history(&mut self) {
756 for project_state in self.projects.values_mut() {
757 project_state.events.clear();
758 project_state.last_event.take();
759 }
760 }
761
762 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
763 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
764 project_state.events.clear();
765 project_state.last_event.take();
766 }
767 }
768
769 pub fn edit_history_for_project(
770 &self,
771 project: &Entity<Project>,
772 cx: &App,
773 ) -> Vec<StoredEvent> {
774 self.projects
775 .get(&project.entity_id())
776 .map(|project_state| project_state.events(cx))
777 .unwrap_or_default()
778 }
779
780 pub fn context_for_project<'a>(
781 &'a self,
782 project: &Entity<Project>,
783 cx: &'a mut App,
784 ) -> Vec<RelatedFile> {
785 self.projects
786 .get(&project.entity_id())
787 .map(|project_state| {
788 project_state.context.update(cx, |context, cx| {
789 context
790 .related_files_with_buffers(cx)
791 .map(|(mut related_file, buffer)| {
792 related_file.in_open_source_repo = buffer
793 .read(cx)
794 .file()
795 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
796 related_file
797 })
798 .collect()
799 })
800 })
801 .unwrap_or_default()
802 }
803
804 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
805 self.projects
806 .get(&project.entity_id())
807 .and_then(|project| project.copilot.clone())
808 }
809
810 pub fn start_copilot_for_project(
811 &mut self,
812 project: &Entity<Project>,
813 cx: &mut Context<Self>,
814 ) -> Option<Entity<Copilot>> {
815 if DisableAiSettings::get(None, cx).disable_ai {
816 return None;
817 }
818 let state = self.get_or_init_project(project, cx);
819
820 if state.copilot.is_some() {
821 return state.copilot.clone();
822 }
823 let _project = project.clone();
824 let project = project.read(cx);
825
826 let node = project.node_runtime().cloned();
827 if let Some(node) = node {
828 let next_id = project.languages().next_language_server_id();
829 let fs = project.fs().clone();
830
831 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
832 state.copilot = Some(copilot.clone());
833 Some(copilot)
834 } else {
835 None
836 }
837 }
838
839 pub fn context_for_project_with_buffers<'a>(
840 &'a self,
841 project: &Entity<Project>,
842 cx: &'a mut App,
843 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
844 self.projects
845 .get(&project.entity_id())
846 .map(|project| {
847 project.context.update(cx, |context, cx| {
848 context.related_files_with_buffers(cx).collect()
849 })
850 })
851 .unwrap_or_default()
852 }
853
854 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
855 if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
856 self.user_store.read(cx).edit_prediction_usage()
857 } else {
858 None
859 }
860 }
861
862 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
863 self.get_or_init_project(project, cx);
864 }
865
866 pub fn register_buffer(
867 &mut self,
868 buffer: &Entity<Buffer>,
869 project: &Entity<Project>,
870 cx: &mut Context<Self>,
871 ) {
872 let project_state = self.get_or_init_project(project, cx);
873 Self::register_buffer_impl(project_state, buffer, project, cx);
874 }
875
    /// Returns the state for `project`, creating it on first access.
    ///
    /// Initialization does three things:
    /// - creates the project's `RelatedExcerptStore`, whose events are forwarded to
    ///   `handle_excerpt_store_event`;
    /// - subscribes to project events (`handle_project_event`);
    /// - removes the state again when the project entity is released.
    fn get_or_init_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ProjectState {
        let entity_id = project.entity_id();
        self.projects
            .entry(entity_id)
            .or_insert_with(|| ProjectState {
                context: {
                    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
                    // Detached: lives as long as the excerpt store emits events.
                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
                        this.handle_excerpt_store_event(entity_id, event);
                    })
                    .detach();
                    related_excerpt_store
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                debug_tx: None,
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                license_detection_watchers: HashMap::default(),
                user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
                _subscriptions: [
                    cx.subscribe(&project, Self::handle_project_event),
                    // Drop all per-project state once the project itself goes away.
                    cx.observe_release(&project, move |this, _, cx| {
                        this.projects.remove(&entity_id);
                        cx.notify();
                    }),
                ],
                copilot: None,
            })
    }
915
916 pub fn remove_project(&mut self, project: &Entity<Project>) {
917 self.projects.remove(&project.entity_id());
918 }
919
920 fn handle_excerpt_store_event(
921 &mut self,
922 project_entity_id: EntityId,
923 event: &RelatedExcerptStoreEvent,
924 ) {
925 if let Some(project_state) = self.projects.get(&project_entity_id) {
926 if let Some(debug_tx) = project_state.debug_tx.clone() {
927 match event {
928 RelatedExcerptStoreEvent::StartedRefresh => {
929 debug_tx
930 .unbounded_send(DebugEvent::ContextRetrievalStarted(
931 ContextRetrievalStartedDebugEvent {
932 project_entity_id: project_entity_id,
933 timestamp: Instant::now(),
934 search_prompt: String::new(),
935 },
936 ))
937 .ok();
938 }
939 RelatedExcerptStoreEvent::FinishedRefresh {
940 cache_hit_count,
941 cache_miss_count,
942 mean_definition_latency,
943 max_definition_latency,
944 } => {
945 debug_tx
946 .unbounded_send(DebugEvent::ContextRetrievalFinished(
947 ContextRetrievalFinishedDebugEvent {
948 project_entity_id: project_entity_id,
949 timestamp: Instant::now(),
950 metadata: vec![
951 (
952 "Cache Hits",
953 format!(
954 "{}/{}",
955 cache_hit_count,
956 cache_hit_count + cache_miss_count
957 )
958 .into(),
959 ),
960 (
961 "Max LSP Time",
962 format!("{} ms", max_definition_latency.as_millis())
963 .into(),
964 ),
965 (
966 "Mean LSP Time",
967 format!("{} ms", mean_definition_latency.as_millis())
968 .into(),
969 ),
970 ],
971 },
972 ))
973 .ok();
974 }
975 }
976 }
977 }
978 }
979
980 pub fn debug_info(
981 &mut self,
982 project: &Entity<Project>,
983 cx: &mut Context<Self>,
984 ) -> mpsc::UnboundedReceiver<DebugEvent> {
985 let project_state = self.get_or_init_project(project, cx);
986 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
987 project_state.debug_tx = Some(debug_watch_tx);
988 debug_watch_rx
989 }
990
991 fn handle_project_event(
992 &mut self,
993 project: Entity<Project>,
994 event: &project::Event,
995 cx: &mut Context<Self>,
996 ) {
997 // TODO [zeta2] init with recent paths
998 match event {
999 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1000 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1001 return;
1002 };
1003 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1004 if let Some(path) = path {
1005 if let Some(ix) = project_state
1006 .recent_paths
1007 .iter()
1008 .position(|probe| probe == &path)
1009 {
1010 project_state.recent_paths.remove(ix);
1011 }
1012 project_state.recent_paths.push_front(path);
1013 }
1014 }
1015 project::Event::DiagnosticsUpdated { .. } => {
1016 if cx.has_flag::<Zeta2FeatureFlag>() {
1017 self.refresh_prediction_from_diagnostics(project, cx);
1018 }
1019 }
1020 _ => (),
1021 }
1022 }
1023
    /// Registers `buffer` with `project_state` if needed and returns its tracking entry.
    ///
    /// Side effects, each performed at most once per key:
    /// - If the buffer's file belongs to one of the project's worktrees, a
    ///   `LicenseDetectionWatcher` is created for that worktree. It is removed again when
    ///   the worktree is released.
    /// - On first registration the buffer is snapshotted and subscribed to. `Edited`
    ///   events are reported via `report_changes_for_buffer` as non-predicted changes,
    ///   for as long as the project is alive. The entry is removed when the buffer is
    ///   released.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Forget the watcher once its worktree is released.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Holds the project weakly so the subscription doesn't keep it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1089
/// Records the edits made to `buffer` since the snapshot last stored for it.
///
/// `is_predicted` is `true` when the change comes from accepting a prediction
/// and `false` for ordinary buffer edits. If the buffer version is unchanged,
/// or no edits are found, nothing is recorded. Otherwise this:
/// - stores the new file and snapshot on the buffer's registration,
/// - classifies the change as a [`UserActionType`] and records it together
///   with the row and offset where the last edit ends,
/// - extends the project's in-progress `last_event` when the change starts from
///   exactly that event's snapshot, has the same `is_predicted` source, and lies
///   within `CHANGE_GROUPING_LINE_SPAN` lines of its edit range,
/// - otherwise finalizes `last_event` into `events` and starts a new event.
fn report_changes_for_buffer(
    &mut self,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    is_predicted: bool,
    cx: &mut Context<Self>,
) {
    let project_state = self.get_or_init_project(project, cx);
    let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

    let buf = buffer.read(cx);
    let new_file = buf.file().cloned();
    let new_snapshot = buf.text_snapshot();
    if new_snapshot.version == registered_buffer.snapshot.version {
        return;
    }

    let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
    let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
    let mut num_edits = 0usize;
    let mut total_deleted = 0usize;
    let mut total_inserted = 0usize;
    // Spans from the first edit's start anchor to the last edit's end anchor.
    let mut edit_range: Option<Range<Anchor>> = None;
    // End offset (in the new snapshot) of the last edit.
    let mut last_offset: Option<usize> = None;

    for (edit, anchor_range) in
        new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
    {
        num_edits += 1;
        total_deleted += edit.old.len();
        total_inserted += edit.new.len();
        edit_range = Some(match edit_range {
            None => anchor_range,
            Some(acc) => acc.start..anchor_range.end,
        });
        last_offset = Some(edit.new.end);
    }

    let Some(edit_range) = edit_range else {
        return;
    };

    // When every edit inserted (or deleted) exactly one unit, the action counts as
    // a single character; anything larger counts as a selection.
    // NOTE(review): lengths are `usize` offset spans (presumably bytes), so a single
    // multi-byte character would be classified as a selection — confirm intended.
    let action_type = match (total_deleted, total_inserted, num_edits) {
        (0, ins, n) if ins == n => UserActionType::InsertChar,
        (0, _, _) => UserActionType::InsertSelection,
        (del, 0, n) if del == n => UserActionType::DeleteChar,
        (_, 0, _) => UserActionType::DeleteSelection,
        (_, ins, n) if ins == n => UserActionType::InsertChar,
        (_, _, _) => UserActionType::InsertSelection,
    };

    if let Some(offset) = last_offset {
        let point = new_snapshot.offset_to_point(offset);
        // Wall-clock timestamp; falls back to 0 if the clock is before the epoch.
        let timestamp_epoch_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        project_state.record_user_action(UserActionRecord {
            action_type,
            buffer_id: buffer.entity_id(),
            line_number: point.row,
            offset,
            timestamp_epoch_ms,
        });
    }

    let events = &mut project_state.events;

    let now = cx.background_executor().now();
    if let Some(last_event) = project_state.last_event.as_mut() {
        // The change continues `last_event` only if it starts from exactly the
        // snapshot that event ended at.
        let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
            == last_event.new_snapshot.remote_id()
            && old_snapshot.version == last_event.new_snapshot.version;

        let prediction_source_changed = is_predicted != last_event.predicted;

        let should_coalesce = is_next_snapshot_of_same_buffer
            && !prediction_source_changed
            && last_event
                .edit_range
                .as_ref()
                .is_some_and(|last_edit_range| {
                    lines_between_ranges(
                        &edit_range.to_point(&new_snapshot),
                        &last_edit_range.to_point(&new_snapshot),
                    ) <= CHANGE_GROUPING_LINE_SPAN
                });

        if should_coalesce {
            // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the
            // snapshot as it was before this burst of edits.
            let pause_elapsed = last_event
                .last_edit_time
                .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                .unwrap_or(false);
            if pause_elapsed {
                last_event.snapshot_after_last_editing_pause =
                    Some(last_event.new_snapshot.clone());
            }

            last_event.edit_range = Some(edit_range);
            last_event.new_snapshot = new_snapshot;
            last_event.last_edit_time = Some(now);
            return;
        }
    }

    if let Some(event) = project_state.last_event.take() {
        if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
            // NOTE(review): evicting when `len + 1 >= EVENT_COUNT_MAX` keeps at most
            // EVENT_COUNT_MAX - 1 finalized events (presumably leaving room for
            // `last_event`) — confirm.
            if events.len() + 1 >= EVENT_COUNT_MAX {
                events.pop_front();
            }
            events.push_back(event);
        }
    }

    merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

    project_state.last_event = Some(LastEvent {
        old_file,
        new_file,
        old_snapshot,
        new_snapshot,
        edit_range: Some(edit_range),
        predicted: is_predicted,
        snapshot_after_last_editing_pause: None,
        last_edit_time: Some(now),
    });
}
1217
1218 fn prediction_at(
1219 &mut self,
1220 buffer: &Entity<Buffer>,
1221 position: Option<language::Anchor>,
1222 project: &Entity<Project>,
1223 cx: &App,
1224 ) -> Option<BufferEditPrediction<'_>> {
1225 let project_state = self.projects.get_mut(&project.entity_id())?;
1226 if let Some(position) = position
1227 && let Some(buffer) = project_state
1228 .registered_buffers
1229 .get_mut(&buffer.entity_id())
1230 {
1231 buffer.last_position = Some(position);
1232 }
1233
1234 let CurrentEditPrediction {
1235 requested_by,
1236 prediction,
1237 ..
1238 } = project_state.current_prediction.as_ref()?;
1239
1240 if prediction.targets_buffer(buffer.read(cx)) {
1241 Some(BufferEditPrediction::Local { prediction })
1242 } else {
1243 let show_jump = match requested_by {
1244 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1245 requested_by_buffer_id == &buffer.entity_id()
1246 }
1247 PredictionRequestedBy::DiagnosticsUpdate => true,
1248 };
1249
1250 if show_jump {
1251 Some(BufferEditPrediction::Jump { prediction })
1252 } else {
1253 None
1254 }
1255 }
1256 }
1257
1258 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1259 let Some(current_prediction) = self
1260 .projects
1261 .get_mut(&project.entity_id())
1262 .and_then(|project_state| project_state.current_prediction.take())
1263 else {
1264 return;
1265 };
1266
1267 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1268
1269 // can't hold &mut project_state ref across report_changes_for_buffer_call
1270 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1271 return;
1272 };
1273
1274 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1275 project_state.cancel_pending_prediction(pending_prediction, cx);
1276 }
1277
1278 match self.edit_prediction_model {
1279 EditPredictionModel::Sweep => {
1280 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1281 }
1282 EditPredictionModel::Mercury => {
1283 mercury::edit_prediction_accepted(
1284 current_prediction.prediction.id,
1285 self.client.http_client(),
1286 cx,
1287 );
1288 }
1289 EditPredictionModel::Ollama => {}
1290 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1291 zeta2::edit_prediction_accepted(self, current_prediction, cx)
1292 }
1293 }
1294 }
1295
1296 async fn handle_rejected_predictions(
1297 rx: UnboundedReceiver<EditPredictionRejection>,
1298 client: Arc<Client>,
1299 llm_token: LlmApiToken,
1300 app_version: Version,
1301 background_executor: BackgroundExecutor,
1302 ) {
1303 let mut rx = std::pin::pin!(rx.peekable());
1304 let mut batched = Vec::new();
1305
1306 while let Some(rejection) = rx.next().await {
1307 batched.push(rejection);
1308
1309 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1310 select_biased! {
1311 next = rx.as_mut().peek().fuse() => {
1312 if next.is_some() {
1313 continue;
1314 }
1315 }
1316 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1317 }
1318 }
1319
1320 let url = client
1321 .http_client()
1322 .build_zed_llm_url("/predict_edits/reject", &[])
1323 .unwrap();
1324
1325 let flush_count = batched
1326 .len()
1327 // in case items have accumulated after failure
1328 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1329 let start = batched.len() - flush_count;
1330
1331 let body = RejectEditPredictionsBodyRef {
1332 rejections: &batched[start..],
1333 };
1334
1335 let result = Self::send_api_request::<()>(
1336 |builder| {
1337 let req = builder
1338 .uri(url.as_ref())
1339 .body(serde_json::to_string(&body)?.into());
1340 anyhow::Ok(req?)
1341 },
1342 client.clone(),
1343 llm_token.clone(),
1344 app_version.clone(),
1345 true,
1346 )
1347 .await;
1348
1349 if result.log_err().is_some() {
1350 batched.drain(start..);
1351 }
1352 }
1353 }
1354
1355 fn reject_current_prediction(
1356 &mut self,
1357 reason: EditPredictionRejectReason,
1358 project: &Entity<Project>,
1359 cx: &App,
1360 ) {
1361 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1362 project_state.pending_predictions.clear();
1363 if let Some(prediction) = project_state.current_prediction.take() {
1364 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1365 }
1366 };
1367 }
1368
1369 fn did_show_current_prediction(
1370 &mut self,
1371 project: &Entity<Project>,
1372 display_type: edit_prediction_types::SuggestionDisplayType,
1373 cx: &mut Context<Self>,
1374 ) {
1375 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1376 return;
1377 };
1378
1379 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1380 return;
1381 };
1382
1383 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1384 let previous_shown_with = current_prediction.shown_with;
1385
1386 if previous_shown_with.is_none() || !is_jump {
1387 current_prediction.shown_with = Some(display_type);
1388 }
1389
1390 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1391
1392 if is_first_non_jump_show {
1393 current_prediction.was_shown = true;
1394 }
1395
1396 let display_type_changed = previous_shown_with != Some(display_type);
1397
1398 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1399 sweep_ai::edit_prediction_shown(
1400 &self.sweep_ai,
1401 self.client.clone(),
1402 ¤t_prediction.prediction,
1403 display_type,
1404 cx,
1405 );
1406 }
1407
1408 if is_first_non_jump_show {
1409 self.shown_predictions
1410 .push_front(current_prediction.prediction.clone());
1411 if self.shown_predictions.len() > 50 {
1412 let completion = self.shown_predictions.pop_back().unwrap();
1413 self.rated_predictions.remove(&completion.id);
1414 }
1415 }
1416 }
1417
1418 fn reject_prediction(
1419 &mut self,
1420 prediction_id: EditPredictionId,
1421 reason: EditPredictionRejectReason,
1422 was_shown: bool,
1423 cx: &App,
1424 ) {
1425 match self.edit_prediction_model {
1426 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1427 self.reject_predictions_tx
1428 .unbounded_send(EditPredictionRejection {
1429 request_id: prediction_id.to_string(),
1430 reason,
1431 was_shown,
1432 })
1433 .log_err();
1434 }
1435 EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
1436 EditPredictionModel::Mercury => {
1437 mercury::edit_prediction_rejected(
1438 prediction_id,
1439 was_shown,
1440 reason,
1441 self.client.http_client(),
1442 cx,
1443 );
1444 }
1445 }
1446 }
1447
1448 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1449 self.projects
1450 .get(&project.entity_id())
1451 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1452 }
1453
1454 pub fn refresh_prediction_from_buffer(
1455 &mut self,
1456 project: Entity<Project>,
1457 buffer: Entity<Buffer>,
1458 position: language::Anchor,
1459 cx: &mut Context<Self>,
1460 ) {
1461 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
1462 let Some(request_task) = this
1463 .update(cx, |this, cx| {
1464 this.request_prediction(
1465 &project,
1466 &buffer,
1467 position,
1468 PredictEditsRequestTrigger::Other,
1469 cx,
1470 )
1471 })
1472 .log_err()
1473 else {
1474 return Task::ready(anyhow::Ok(None));
1475 };
1476
1477 cx.spawn(async move |_cx| {
1478 request_task.await.map(|prediction_result| {
1479 prediction_result.map(|prediction_result| {
1480 (
1481 prediction_result,
1482 PredictionRequestedBy::Buffer(buffer.entity_id()),
1483 )
1484 })
1485 })
1486 })
1487 })
1488 }
1489
/// Queues a prediction refresh in response to a diagnostics update in `project`.
///
/// Does nothing if the project is not tracked or already has a current
/// prediction. Otherwise it takes the project's active buffer and cursor, and
/// proceeds only if predictions are enabled there. It then looks for the next
/// diagnostic location and requests a prediction at that location. The result
/// is attributed to `PredictionRequestedBy::DiagnosticsUpdate`. If a current
/// prediction appeared in the meantime, the new result is turned into a
/// `CurrentPreferred` rejection instead.
pub fn refresh_prediction_from_diagnostics(
    &mut self,
    project: Entity<Project>,
    cx: &mut Context<Self>,
) {
    let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
        return;
    };

    // Prefer predictions from buffer
    if project_state.current_prediction.is_some() {
        return;
    };

    // Diagnostics-triggered refreshes are throttled per project.
    self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
        let Some((active_buffer, snapshot, cursor_point)) = this
            .read_with(cx, |this, cx| {
                let project_state = this.projects.get(&project.entity_id())?;
                let (buffer, position) = project_state.active_buffer(&project, cx)?;
                let snapshot = buffer.read(cx).snapshot();

                if !Self::predictions_enabled_at(&snapshot, position, cx) {
                    return None;
                }

                // Without a known cursor position, search from the default point.
                let cursor_point = position
                    .map(|pos| pos.to_point(&snapshot))
                    .unwrap_or_default();

                Some((buffer, snapshot, cursor_point))
            })
            .log_err()
            .flatten()
        else {
            return Task::ready(anyhow::Ok(None));
        };

        cx.spawn(async move |cx| {
            // Default (empty) search range, so presumably no diagnostic in the
            // active buffer is skipped as already covered — confirm.
            let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                active_buffer,
                &snapshot,
                Default::default(),
                cursor_point,
                &project,
                cx,
            )
            .await?
            else {
                return anyhow::Ok(None);
            };

            let Some(prediction_result) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &jump_buffer,
                        jump_position,
                        PredictEditsRequestTrigger::Diagnostics,
                        cx,
                    )
                })?
                .await?
            else {
                return anyhow::Ok(None);
            };

            // A prediction may have arrived while this request was in flight;
            // if so, keep it and report this one as CurrentPreferred.
            this.update(cx, |this, cx| {
                Some((
                    if this
                        .get_or_init_project(&project, cx)
                        .current_prediction
                        .is_none()
                    {
                        prediction_result
                    } else {
                        EditPredictionResult {
                            id: prediction_result.id,
                            prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                        }
                    },
                    PredictionRequestedBy::DiagnosticsUpdate,
                ))
            })
        })
    });
}
1576
1577 fn predictions_enabled_at(
1578 snapshot: &BufferSnapshot,
1579 position: Option<language::Anchor>,
1580 cx: &App,
1581 ) -> bool {
1582 let file = snapshot.file();
1583 let all_settings = all_language_settings(file, cx);
1584 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1585 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1586 {
1587 return false;
1588 }
1589
1590 if let Some(last_position) = position {
1591 let settings = snapshot.settings_at(last_position, cx);
1592
1593 if !settings.edit_predictions_disabled_in.is_empty()
1594 && let Some(scope) = snapshot.language_scope_at(last_position)
1595 && let Some(scope_name) = scope.override_name()
1596 && settings
1597 .edit_predictions_disabled_in
1598 .iter()
1599 .any(|s| s == scope_name)
1600 {
1601 return false;
1602 }
1603 }
1604
1605 true
1606 }
1607
/// Minimum spacing between two prediction refreshes for the same throttle
/// entity (applied in `queue_prediction_refresh`).
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
/// Tests disable refresh throttling.
#[cfg(test)]
pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1612
1613 fn queue_prediction_refresh(
1614 &mut self,
1615 project: Entity<Project>,
1616 throttle_entity: EntityId,
1617 cx: &mut Context<Self>,
1618 do_refresh: impl FnOnce(
1619 WeakEntity<Self>,
1620 &mut AsyncApp,
1621 )
1622 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1623 + 'static,
1624 ) {
1625 let is_ollama = self.edit_prediction_model == EditPredictionModel::Ollama;
1626 let drop_on_cancel = is_ollama;
1627 let max_pending_predictions = if is_ollama { 1 } else { 2 };
1628 let project_state = self.get_or_init_project(&project, cx);
1629 let pending_prediction_id = project_state.next_pending_prediction_id;
1630 project_state.next_pending_prediction_id += 1;
1631 let last_request = project_state.last_prediction_refresh;
1632
1633 let task = cx.spawn(async move |this, cx| {
1634 if let Some((last_entity, last_timestamp)) = last_request
1635 && throttle_entity == last_entity
1636 && let Some(timeout) =
1637 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1638 {
1639 cx.background_executor().timer(timeout).await;
1640 }
1641
1642 // If this task was cancelled before the throttle timeout expired,
1643 // do not perform a request.
1644 let mut is_cancelled = true;
1645 this.update(cx, |this, cx| {
1646 let project_state = this.get_or_init_project(&project, cx);
1647 if !project_state
1648 .cancelled_predictions
1649 .remove(&pending_prediction_id)
1650 {
1651 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1652 is_cancelled = false;
1653 }
1654 })
1655 .ok();
1656 if is_cancelled {
1657 return None;
1658 }
1659
1660 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1661 let new_prediction_id = new_prediction_result
1662 .as_ref()
1663 .map(|(prediction, _)| prediction.id.clone());
1664
1665 // When a prediction completes, remove it from the pending list, and cancel
1666 // any pending predictions that were enqueued before it.
1667 this.update(cx, |this, cx| {
1668 let project_state = this.get_or_init_project(&project, cx);
1669
1670 let is_cancelled = project_state
1671 .cancelled_predictions
1672 .remove(&pending_prediction_id);
1673
1674 let new_current_prediction = if !is_cancelled
1675 && let Some((prediction_result, requested_by)) = new_prediction_result
1676 {
1677 match prediction_result.prediction {
1678 Ok(prediction) => {
1679 let new_prediction = CurrentEditPrediction {
1680 requested_by,
1681 prediction,
1682 was_shown: false,
1683 shown_with: None,
1684 };
1685
1686 if let Some(current_prediction) =
1687 project_state.current_prediction.as_ref()
1688 {
1689 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1690 {
1691 this.reject_current_prediction(
1692 EditPredictionRejectReason::Replaced,
1693 &project,
1694 cx,
1695 );
1696
1697 Some(new_prediction)
1698 } else {
1699 this.reject_prediction(
1700 new_prediction.prediction.id,
1701 EditPredictionRejectReason::CurrentPreferred,
1702 false,
1703 cx,
1704 );
1705 None
1706 }
1707 } else {
1708 Some(new_prediction)
1709 }
1710 }
1711 Err(reject_reason) => {
1712 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1713 None
1714 }
1715 }
1716 } else {
1717 None
1718 };
1719
1720 let project_state = this.get_or_init_project(&project, cx);
1721
1722 if let Some(new_prediction) = new_current_prediction {
1723 project_state.current_prediction = Some(new_prediction);
1724 }
1725
1726 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1727 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1728 if pending_prediction.id == pending_prediction_id {
1729 pending_predictions.remove(ix);
1730 for pending_prediction in pending_predictions.drain(0..ix) {
1731 project_state.cancel_pending_prediction(pending_prediction, cx)
1732 }
1733 break;
1734 }
1735 }
1736 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1737 cx.notify();
1738 })
1739 .ok();
1740
1741 new_prediction_id
1742 });
1743
1744 if project_state.pending_predictions.len() < max_pending_predictions {
1745 project_state.pending_predictions.push(PendingPrediction {
1746 id: pending_prediction_id,
1747 task,
1748 drop_on_cancel,
1749 });
1750 } else {
1751 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1752 project_state.pending_predictions.push(PendingPrediction {
1753 id: pending_prediction_id,
1754 task,
1755 drop_on_cancel,
1756 });
1757 project_state.cancel_pending_prediction(pending_prediction, cx);
1758 }
1759 }
1760
1761 pub fn request_prediction(
1762 &mut self,
1763 project: &Entity<Project>,
1764 active_buffer: &Entity<Buffer>,
1765 position: language::Anchor,
1766 trigger: PredictEditsRequestTrigger,
1767 cx: &mut Context<Self>,
1768 ) -> Task<Result<Option<EditPredictionResult>>> {
1769 self.request_prediction_internal(
1770 project.clone(),
1771 active_buffer.clone(),
1772 position,
1773 trigger,
1774 cx.has_flag::<Zeta2FeatureFlag>(),
1775 cx,
1776 )
1777 }
1778
/// Builds the model input for a prediction at `position` in `active_buffer` and
/// dispatches it to the configured `edit_prediction_model`.
///
/// The input carries the project's recent events, related files, recent paths,
/// and recorded user actions. A synthetic `CursorMovement` action is appended
/// when the cursor is in the same buffer as the last recorded action but at a
/// different offset. A window of `DIAGNOSTIC_LINES_RANGE` rows on either side of
/// the cursor is passed to the model as `diagnostic_search_range`.
///
/// If the model returns no prediction, `allow_jump` is set, and the project has
/// recorded events, this retries once at the next diagnostic location outside
/// that window. The retry passes `allow_jump = false`, so it does not recurse
/// further.
fn request_prediction_internal(
    &mut self,
    project: Entity<Project>,
    active_buffer: Entity<Buffer>,
    position: language::Anchor,
    trigger: PredictEditsRequestTrigger,
    allow_jump: bool,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Rows on each side of the cursor included in the diagnostic search range.
    const DIAGNOSTIC_LINES_RANGE: u32 = 20;

    // Ensure the project state exists, then borrow it immutably while gathering inputs.
    self.get_or_init_project(&project, cx);
    let project_state = self.projects.get(&project.entity_id()).unwrap();
    let stored_events = project_state.events(cx);
    let has_events = !stored_events.is_empty();
    let events: Vec<Arc<zeta_prompt::Event>> =
        stored_events.into_iter().map(|e| e.event).collect();
    let debug_tx = project_state.debug_tx.clone();

    let snapshot = active_buffer.read(cx).snapshot();
    let cursor_point = position.to_point(&snapshot);
    let current_offset = position.to_offset(&snapshot);

    let mut user_actions: Vec<UserActionRecord> =
        project_state.user_actions.iter().cloned().collect();

    // The cursor moved since the last recorded action in this buffer: add a
    // cursor-movement record to this request's copy of the actions only.
    if let Some(last_action) = user_actions.last() {
        if last_action.buffer_id == active_buffer.entity_id()
            && current_offset != last_action.offset
        {
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            user_actions.push(UserActionRecord {
                action_type: UserActionType::CursorMovement,
                buffer_id: active_buffer.entity_id(),
                line_number: cursor_point.row,
                offset: current_offset,
                timestamp_epoch_ms,
            });
        }
    }
    let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
    let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
    let diagnostic_search_range =
        Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

    let related_files = self.context_for_project(&project, cx);

    let inputs = EditPredictionModelInput {
        project: project.clone(),
        buffer: active_buffer.clone(),
        snapshot: snapshot.clone(),
        position,
        events,
        related_files,
        recent_paths: project_state.recent_paths.clone(),
        trigger,
        diagnostic_search_range: diagnostic_search_range.clone(),
        debug_tx,
        user_actions,
    };

    // Zeta1 and Zeta2 share one request path and differ only in the model kind.
    let task = match &self.edit_prediction_model {
        EditPredictionModel::Zeta1 => zeta2::request_prediction_with_zeta2(
            self,
            inputs,
            Some(zeta_prompt::EditPredictionModelKind::Zeta1),
            cx,
        ),
        EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
            self,
            inputs,
            Some(zeta_prompt::EditPredictionModelKind::Zeta2),
            cx,
        ),
        EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
        EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
        EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
    };

    cx.spawn(async move |this, cx| {
        let prediction = task.await?;

        // No prediction here: optionally retry once at the next diagnostic
        // outside the searched window (the retry disallows further jumps).
        if prediction.is_none() && allow_jump {
            let cursor_point = position.to_point(&snapshot);
            if has_events
                && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer.clone(),
                    &snapshot,
                    diagnostic_search_range,
                    cursor_point,
                    &project,
                    cx,
                )
                .await?
            {
                return this
                    .update(cx, |this, cx| {
                        this.request_prediction_internal(
                            project,
                            jump_buffer,
                            jump_position,
                            trigger,
                            false,
                            cx,
                        )
                    })?
                    .await;
            }

            return anyhow::Ok(None);
        }

        Ok(prediction)
    })
}
1897
/// Finds a location to jump to for a diagnostics-driven prediction.
///
/// First it searches the active buffer. Among diagnostic groups whose primary
/// entry does not overlap `active_buffer_diagnostic_search_range` (the area the
/// previous request already covered), it picks the one whose start row is
/// closest to `active_buffer_cursor_point`.
///
/// If none is found, it opens another project file that has diagnostic
/// summaries, preferring the one that shares the most leading path components
/// with the active buffer's path. It returns the start of that buffer's
/// diagnostic with the lowest `severity` value (presumably the most severe, as
/// in LSP numbering — confirm).
///
/// Returns `Ok(None)` when no location is found. Errors come from opening the
/// other buffer.
async fn next_diagnostic_location(
    active_buffer: Entity<Buffer>,
    active_buffer_snapshot: &BufferSnapshot,
    active_buffer_diagnostic_search_range: Range<Point>,
    active_buffer_cursor_point: Point,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
    // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
    let mut jump_location = active_buffer_snapshot
        .diagnostic_groups(None)
        .into_iter()
        .filter_map(|(_, group)| {
            let range = &group.entries[group.primary_ix]
                .range
                .to_point(&active_buffer_snapshot);
            if range.overlaps(&active_buffer_diagnostic_search_range) {
                None
            } else {
                Some(range.start)
            }
        })
        .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
        .map(|position| {
            (
                active_buffer.clone(),
                active_buffer_snapshot.anchor_before(position),
            )
        });

    if jump_location.is_none() {
        let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
            let file = buffer.file()?;

            Some(ProjectPath {
                worktree_id: file.worktree_id(cx),
                path: file.path().clone(),
            })
        });

        let buffer_task = project.update(cx, |project, cx| {
            let (path, _, _) = project
                .diagnostic_summaries(false, cx)
                // Skip the active buffer's own file.
                .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                // NOTE(review): the prefix comparison uses only `path.path`, not the
                // worktree id, so files in different worktrees can tie.
                .max_by_key(|(path, _, _)| {
                    // find the buffer with errors that shares most parent directories
                    path.path
                        .components()
                        .zip(
                            active_buffer_path
                                .as_ref()
                                .map(|p| p.path.components())
                                .unwrap_or_default(),
                        )
                        .take_while(|(a, b)| a == b)
                        .count()
                })?;

            Some(project.open_buffer(path, cx))
        });

        if let Some(buffer_task) = buffer_task {
            let closest_buffer = buffer_task.await?;

            jump_location = closest_buffer
                .read_with(cx, |buffer, _cx| {
                    buffer
                        .buffer_diagnostics(None)
                        .into_iter()
                        .min_by_key(|entry| entry.diagnostic.severity)
                        .map(|entry| entry.range.start)
                })
                .map(|position| (closest_buffer, position));
        }
    }

    anyhow::Ok(jump_location)
}
1976
1977 async fn send_raw_llm_request(
1978 request: RawCompletionRequest,
1979 client: Arc<Client>,
1980 custom_url: Option<Arc<Url>>,
1981 llm_token: LlmApiToken,
1982 app_version: Version,
1983 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
1984 let url = if let Some(custom_url) = custom_url {
1985 custom_url.as_ref().clone()
1986 } else {
1987 client
1988 .http_client()
1989 .build_zed_llm_url("/predict_edits/raw", &[])?
1990 };
1991
1992 Self::send_api_request(
1993 |builder| {
1994 let req = builder
1995 .uri(url.as_ref())
1996 .body(serde_json::to_string(&request)?.into());
1997 Ok(req?)
1998 },
1999 client,
2000 llm_token,
2001 app_version,
2002 true,
2003 )
2004 .await
2005 }
2006
2007 pub(crate) async fn send_v3_request(
2008 input: ZetaPromptInput,
2009 client: Arc<Client>,
2010 llm_token: LlmApiToken,
2011 app_version: Version,
2012 trigger: PredictEditsRequestTrigger,
2013 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2014 let url = client
2015 .http_client()
2016 .build_zed_llm_url("/predict_edits/v3", &[])?;
2017
2018 let request = PredictEditsV3Request { input, trigger };
2019
2020 let json_bytes = serde_json::to_vec(&request)?;
2021 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2022
2023 Self::send_api_request(
2024 |builder| {
2025 let req = builder
2026 .uri(url.as_ref())
2027 .header("Content-Encoding", "zstd")
2028 .body(compressed.clone().into());
2029 Ok(req?)
2030 },
2031 client,
2032 llm_token,
2033 app_version,
2034 true,
2035 )
2036 .await
2037 }
2038
2039 fn handle_api_response<T>(
2040 this: &WeakEntity<Self>,
2041 response: Result<(T, Option<EditPredictionUsage>)>,
2042 cx: &mut gpui::AsyncApp,
2043 ) -> Result<T> {
2044 match response {
2045 Ok((data, usage)) => {
2046 if let Some(usage) = usage {
2047 this.update(cx, |this, cx| {
2048 this.user_store.update(cx, |user_store, cx| {
2049 user_store.update_edit_prediction_usage(usage, cx);
2050 });
2051 })
2052 .ok();
2053 }
2054 Ok(data)
2055 }
2056 Err(err) => {
2057 if err.is::<ZedUpdateRequiredError>() {
2058 cx.update(|cx| {
2059 this.update(cx, |this, _cx| {
2060 this.update_required = true;
2061 })
2062 .ok();
2063
2064 let error_message: SharedString = err.to_string().into();
2065 show_app_notification(
2066 NotificationId::unique::<ZedUpdateRequiredError>(),
2067 cx,
2068 move |cx| {
2069 cx.new(|cx| {
2070 ErrorMessagePrompt::new(error_message.clone(), cx)
2071 .with_link_button("Update Zed", "https://zed.dev/releases")
2072 })
2073 },
2074 );
2075 });
2076 }
2077 Err(err)
2078 }
2079 }
2080 }
2081
/// Sends a POST request built by `build` and decodes a JSON `Res` from a
/// successful response.
///
/// Every attempt sets `Content-Type: application/json` and the Zed version
/// header. The bearer token comes from the `ZED_PREDICT_EDITS_TOKEN` environment
/// variable if set. Otherwise the LLM token is acquired: a failure is an error
/// when `require_auth` is true, and the request is sent without an
/// `Authorization` header when it is false.
///
/// Errors:
/// - `ZedUpdateRequiredError` if the response's minimum-version header is newer
///   than `app_version`,
/// - an error with the status and body for any other non-success response.
///
/// When a token was sent and the response asks for a token refresh, the token
/// is refreshed and the request is retried exactly once, calling `build` again.
/// Usage parsed from the response headers is returned alongside the payload.
async fn send_api_request<Res>(
    build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
    client: Arc<Client>,
    llm_token: LlmApiToken,
    app_version: Version,
    require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
    Res: DeserializeOwned,
{
    let http_client = client.http_client();

    // Environment override first; otherwise acquire an LLM token, which is
    // mandatory only when `require_auth` is set.
    let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
        Some(custom_token)
    } else if require_auth {
        Some(llm_token.acquire(&client).await?)
    } else {
        llm_token.acquire(&client).await.ok()
    };
    let mut did_retry = false;

    loop {
        let request_builder = http_client::Request::builder().method(Method::POST);

        let mut request_builder = request_builder
            .header("Content-Type", "application/json")
            .header(ZED_VERSION_HEADER_NAME, app_version.to_string());

        // Only add Authorization header if we have a token
        if let Some(ref token_value) = token {
            request_builder =
                request_builder.header("Authorization", format!("Bearer {}", token_value));
        }

        let request = build(request_builder)?;

        let mut response = http_client.send(request).await?;

        // The server may advertise a minimum supported client version; fail with a
        // dedicated error type when this build is older.
        if let Some(minimum_required_version) = response
            .headers()
            .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
            .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
        {
            anyhow::ensure!(
                app_version >= minimum_required_version,
                ZedUpdateRequiredError {
                    minimum_version: minimum_required_version
                }
            );
        }

        if response.status().is_success() {
            let usage = EditPredictionUsage::from_headers(response.headers()).ok();

            let mut body = Vec::new();
            response.body_mut().read_to_end(&mut body).await?;
            return Ok((serde_json::from_slice(&body)?, usage));
        } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
            // Refresh the token and retry once.
            did_retry = true;
            token = Some(llm_token.refresh(&client).await?);
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "Request failed with status: {:?}\nBody: {}",
                response.status(),
                body
            );
        }
    }
}
2153
2154 pub fn refresh_context(
2155 &mut self,
2156 project: &Entity<Project>,
2157 buffer: &Entity<language::Buffer>,
2158 cursor_position: language::Anchor,
2159 cx: &mut Context<Self>,
2160 ) {
2161 self.get_or_init_project(project, cx)
2162 .context
2163 .update(cx, |store, cx| {
2164 store.refresh(buffer.clone(), cursor_position, cx);
2165 });
2166 }
2167
/// Replaces the related files held by the project's context store
/// (CLI support only), initializing the project state if needed.
#[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
    &mut self,
    project: &Entity<Project>,
    related_files: Vec<RelatedFile>,
    cx: &mut Context<Self>,
) {
    let context_store = &self.get_or_init_project(project, cx).context;
    context_store.update(cx, |store, cx| {
        store.set_related_files(related_files, cx);
    });
}
2181
2182 #[cfg(feature = "cli-support")]
2183 pub fn set_recent_paths_for_project(
2184 &mut self,
2185 project: &Entity<Project>,
2186 paths: impl IntoIterator<Item = project::ProjectPath>,
2187 cx: &mut Context<Self>,
2188 ) {
2189 let project_state = self.get_or_init_project(project, cx);
2190 project_state.recent_paths = paths.into_iter().collect();
2191 }
2192
2193 fn is_file_open_source(
2194 &self,
2195 project: &Entity<Project>,
2196 file: &Arc<dyn File>,
2197 cx: &App,
2198 ) -> bool {
2199 if !file.is_local() || file.is_private() {
2200 return false;
2201 }
2202 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2203 return false;
2204 };
2205 project_state
2206 .license_detection_watchers
2207 .get(&file.worktree_id(cx))
2208 .as_ref()
2209 .is_some_and(|watcher| watcher.is_project_open_source())
2210 }
2211
2212 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2213 self.data_collection_choice.is_enabled(cx)
2214 }
2215
2216 fn load_data_collection_choice() -> DataCollectionChoice {
2217 let choice = KEY_VALUE_STORE
2218 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2219 .log_err()
2220 .flatten();
2221
2222 match choice.as_deref() {
2223 Some("true") => DataCollectionChoice::Enabled,
2224 Some("false") => DataCollectionChoice::Disabled,
2225 Some(_) => {
2226 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2227 DataCollectionChoice::NotAnswered
2228 }
2229 None => DataCollectionChoice::NotAnswered,
2230 }
2231 }
2232
2233 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2234 self.data_collection_choice = self.data_collection_choice.toggle();
2235 let new_choice = self.data_collection_choice;
2236 let is_enabled = new_choice.is_enabled(cx);
2237 db::write_and_log(cx, move || {
2238 KEY_VALUE_STORE.write_kvp(
2239 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2240 is_enabled.to_string(),
2241 )
2242 });
2243 }
2244
2245 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2246 self.shown_predictions.iter()
2247 }
2248
2249 pub fn shown_completions_len(&self) -> usize {
2250 self.shown_predictions.len()
2251 }
2252
2253 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2254 self.rated_predictions.contains(id)
2255 }
2256
2257 pub fn rate_prediction(
2258 &mut self,
2259 prediction: &EditPrediction,
2260 rating: EditPredictionRating,
2261 feedback: String,
2262 cx: &mut Context<Self>,
2263 ) {
2264 let organization = self.user_store.read(cx).current_organization();
2265
2266 self.rated_predictions.insert(prediction.id.clone());
2267
2268 cx.background_spawn({
2269 let client = self.client.clone();
2270 let prediction_id = prediction.id.to_string();
2271 let inputs = serde_json::to_value(&prediction.inputs);
2272 let output = prediction
2273 .edit_preview
2274 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2275 async move {
2276 client
2277 .cloud_client()
2278 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2279 organization_id: organization.map(|organization| organization.id.clone()),
2280 request_id: prediction_id,
2281 rating: match rating {
2282 EditPredictionRating::Positive => "positive".to_string(),
2283 EditPredictionRating::Negative => "negative".to_string(),
2284 },
2285 inputs: inputs?,
2286 output,
2287 feedback,
2288 })
2289 .await?;
2290
2291 anyhow::Ok(())
2292 }
2293 })
2294 .detach_and_log_err(cx);
2295
2296 cx.notify();
2297 }
2298}
2299
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// The merged event's diff spans from the oldest merged event's `old_snapshot`
/// to `end_snapshot`, and its `edit_range` is anchored in `end_snapshot`.
/// `old_path`, `path` and `in_open_source_repo` are taken from the oldest merged
/// event. `predicted` is true only if every merged event was a predicted
/// `BufferChange`.
///
/// `events` is left untouched when:
/// - the newest event's `old_snapshot` belongs to a different buffer than
///   `latest_snapshot` (compared by `remote_id`);
/// - fewer than two trailing events are mergeable;
/// - `compute_diff_between_snapshots` returns `None`.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Only merge when the most recent event is about the buffer being edited.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Walk backwards from the newest event, counting how many consecutive events
    // can be merged. Each older event is checked against its immediate successor.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // A single event (or none) has nothing to merge with.
    if mergeable_count <= 1 {
        return;
    }

    // `events_to_merge` yields the trailing run oldest-first; peeking gives the
    // oldest event without consuming it, so the `all(...)` below still sees it.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged change counts as predicted only if every
                    // constituent change was predicted.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the whole trailing run with the single merged event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2366
2367pub(crate) fn filter_redundant_excerpts(
2368 mut related_files: Vec<RelatedFile>,
2369 cursor_path: &Path,
2370 cursor_row_range: Range<u32>,
2371) -> Vec<RelatedFile> {
2372 for file in &mut related_files {
2373 if file.path.as_ref() == cursor_path {
2374 file.excerpts.retain(|excerpt| {
2375 excerpt.row_range.start < cursor_row_range.start
2376 || excerpt.row_range.end > cursor_row_range.end
2377 });
2378 }
2379 }
2380 related_files.retain(|file| !file.excerpts.is_empty());
2381 related_files
2382}
2383
2384#[derive(Error, Debug)]
2385#[error(
2386 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2387)]
2388pub struct ZedUpdateRequiredError {
2389 minimum_version: Version,
2390}
2391
/// The user's answer to whether edit-prediction data may be collected, as
/// stored in the key-value store.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No answer has been recorded.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2398
2399impl DataCollectionChoice {
2400 pub fn is_enabled(self, cx: &App) -> bool {
2401 if cx.is_staff() {
2402 return true;
2403 }
2404 match self {
2405 Self::Enabled => true,
2406 Self::NotAnswered | Self::Disabled => false,
2407 }
2408 }
2409
2410 #[must_use]
2411 pub fn toggle(&self) -> DataCollectionChoice {
2412 match self {
2413 Self::Enabled => Self::Disabled,
2414 Self::Disabled => Self::Enabled,
2415 Self::NotAnswered => Self::Enabled,
2416 }
2417 }
2418}
2419
2420impl From<bool> for DataCollectionChoice {
2421 fn from(value: bool) -> Self {
2422 match value {
2423 true => DataCollectionChoice::Enabled,
2424 false => DataCollectionChoice::Disabled,
2425 }
2426 }
2427}
2428
2429struct ZedPredictUpsell;
2430
2431impl Dismissable for ZedPredictUpsell {
2432 const KEY: &'static str = "dismissed-edit-predict-upsell";
2433
2434 fn dismissed() -> bool {
2435 // To make this backwards compatible with older versions of Zed, we
2436 // check if the user has seen the previous Edit Prediction Onboarding
2437 // before, by checking the data collection choice which was written to
2438 // the database once the user clicked on "Accept and Enable"
2439 if KEY_VALUE_STORE
2440 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2441 .log_err()
2442 .is_some_and(|s| s.is_some())
2443 {
2444 return true;
2445 }
2446
2447 KEY_VALUE_STORE
2448 .read_kvp(Self::KEY)
2449 .log_err()
2450 .is_some_and(|s| s.is_some())
2451 }
2452}
2453
2454pub fn should_show_upsell_modal() -> bool {
2455 !ZedPredictUpsell::dismissed()
2456}
2457
2458pub fn init(cx: &mut App) {
2459 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2460 workspace.register_action(
2461 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2462 ZedPredictModal::toggle(
2463 workspace,
2464 workspace.user_store().clone(),
2465 workspace.client().clone(),
2466 window,
2467 cx,
2468 )
2469 },
2470 );
2471
2472 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2473 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2474 settings
2475 .project
2476 .all_languages
2477 .edit_predictions
2478 .get_or_insert_default()
2479 .provider = Some(EditPredictionProvider::None)
2480 });
2481 });
2482 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2483 EditPredictionStore::try_global(cx).and_then(|store| {
2484 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2485 })
2486 }
2487
2488 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2489 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2490 copilot_ui::initiate_sign_in(copilot, window, cx);
2491 }
2492 });
2493 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2494 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2495 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2496 }
2497 });
2498 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2499 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2500 copilot_ui::initiate_sign_out(copilot, window, cx);
2501 }
2502 });
2503 })
2504 .detach();
2505}