1use anyhow::Result;
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_api_types::SubmitEditPredictionFeedbackBody;
5use cloud_llm_client::predict_edits_v3::{
6 PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
7};
8use cloud_llm_client::{
9 EditPredictionRejectReason, EditPredictionRejection,
10 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
11 PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
12};
13use collections::{HashMap, HashSet};
14use copilot::{Copilot, Reinstall, SignIn, SignOut};
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::{
19 AsyncReadExt as _, FutureExt as _, StreamExt as _,
20 channel::mpsc::{self, UnboundedReceiver},
21 select_biased,
22};
23use gpui::BackgroundExecutor;
24use gpui::http_client::Url;
25use gpui::{
26 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::language_settings::all_language_settings;
31use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{
39 EditPredictionPromptFormat, EditPredictionProvider, Settings as _, update_settings_file,
40};
41use std::collections::{VecDeque, hash_map};
42use std::env;
43use text::Edit;
44use workspace::Workspace;
45use zeta_prompt::{ZetaFormat, ZetaPromptInput};
46
47use std::mem;
48use std::ops::Range;
49use std::path::Path;
50use std::rc::Rc;
51use std::str::FromStr as _;
52use std::sync::Arc;
53use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
54use thiserror::Error;
55use util::{RangeExt as _, ResultExt as _};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod cursor_excerpt;
59pub mod example_spec;
60pub mod fim;
61mod license_detection;
62pub mod mercury;
63pub mod ollama;
64mod onboarding_modal;
65pub mod open_ai_response;
66mod prediction;
67pub mod sweep_ai;
68
69pub mod udiff;
70
71mod capture_example;
72mod zed_edit_prediction_delegate;
73pub mod zeta;
74
75#[cfg(test)]
76mod edit_prediction_tests;
77
78use crate::license_detection::LicenseDetectionWatcher;
79use crate::mercury::Mercury;
80use crate::onboarding_modal::ZedPredictModal;
81pub use crate::prediction::EditPrediction;
82pub use crate::prediction::EditPredictionId;
83use crate::prediction::EditPredictionResult;
84pub use crate::sweep_ai::SweepAi;
85pub use capture_example::capture_example;
86pub use language_model::ApiKeyState;
87pub use telemetry_events::EditPredictionRating;
88pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
89
90actions!(
91 edit_prediction,
92 [
93 /// Resets the edit prediction onboarding state.
94 ResetOnboarding,
95 /// Clears the edit prediction history.
96 ClearHistory,
97 ]
98);
99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;

/// Line-distance threshold used by `StoredEvent::can_merge`: edits within this many
/// rows of each other (or of the latest edit) are treated as "near".
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;

/// Time window for grouping consecutive changes.
/// NOTE(review): not referenced in this part of the file — confirm where it is applied.
const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_millis(1_000);

/// Storage key for the user's data-collection choice.
/// NOTE(review): presumably a key-value-store key read by `load_data_collection_choice`; verify.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

/// Debounce interval for reporting rejected predictions.
/// NOTE(review): presumably consumed by `handle_rejected_predictions`; confirm.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
106
107pub struct Zeta2FeatureFlag;
108
109impl FeatureFlag for Zeta2FeatureFlag {
110 const NAME: &'static str = "zeta2";
111
112 fn enabled_for_staff() -> bool {
113 true
114 }
115}
116
117#[derive(Clone)]
118struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
119
120impl Global for EditPredictionStoreGlobal {}
121
122/// Configuration for using the raw Zeta2 endpoint.
123/// When set, the client uses the raw endpoint and constructs the prompt itself.
124/// The version is also used as the Baseten environment name (lowercased).
125#[derive(Clone)]
126pub struct Zeta2RawConfig {
127 pub model_id: Option<String>,
128 pub format: ZetaFormat,
129}
130
131pub struct EditPredictionStore {
132 client: Arc<Client>,
133 user_store: Entity<UserStore>,
134 llm_token: LlmApiToken,
135 _llm_token_subscription: Subscription,
136 projects: HashMap<EntityId, ProjectState>,
137 update_required: bool,
138 edit_prediction_model: EditPredictionModel,
139 zeta2_raw_config: Option<Zeta2RawConfig>,
140 pub sweep_ai: SweepAi,
141 pub mercury: Mercury,
142 data_collection_choice: DataCollectionChoice,
143 reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
144 shown_predictions: VecDeque<EditPrediction>,
145 rated_predictions: HashSet<EditPredictionId>,
146}
147
148#[derive(Copy, Clone, PartialEq, Eq)]
149pub enum EditPredictionModel {
150 Zeta1,
151 Zeta2,
152 Fim { format: EditPredictionPromptFormat },
153 Sweep,
154 Mercury,
155}
156
157#[derive(Clone)]
158pub struct EditPredictionModelInput {
159 project: Entity<Project>,
160 buffer: Entity<Buffer>,
161 snapshot: BufferSnapshot,
162 position: Anchor,
163 events: Vec<Arc<zeta_prompt::Event>>,
164 related_files: Vec<RelatedFile>,
165 recent_paths: VecDeque<ProjectPath>,
166 trigger: PredictEditsRequestTrigger,
167 diagnostic_search_range: Range<Point>,
168 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
169 pub user_actions: Vec<UserActionRecord>,
170}
171
172#[derive(Debug)]
173pub enum DebugEvent {
174 ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
175 ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
176 EditPredictionStarted(EditPredictionStartedDebugEvent),
177 EditPredictionFinished(EditPredictionFinishedDebugEvent),
178}
179
180#[derive(Debug)]
181pub struct ContextRetrievalStartedDebugEvent {
182 pub project_entity_id: EntityId,
183 pub timestamp: Instant,
184 pub search_prompt: String,
185}
186
187#[derive(Debug)]
188pub struct ContextRetrievalFinishedDebugEvent {
189 pub project_entity_id: EntityId,
190 pub timestamp: Instant,
191 pub metadata: Vec<(&'static str, SharedString)>,
192}
193
194#[derive(Debug)]
195pub struct EditPredictionStartedDebugEvent {
196 pub buffer: WeakEntity<Buffer>,
197 pub position: Anchor,
198 pub prompt: Option<String>,
199}
200
201#[derive(Debug)]
202pub struct EditPredictionFinishedDebugEvent {
203 pub buffer: WeakEntity<Buffer>,
204 pub position: Anchor,
205 pub model_output: Option<String>,
206}
207
208const USER_ACTION_HISTORY_SIZE: usize = 16;
209
210#[derive(Clone, Debug)]
211pub struct UserActionRecord {
212 pub action_type: UserActionType,
213 pub buffer_id: EntityId,
214 pub line_number: u32,
215 pub offset: usize,
216 pub timestamp_epoch_ms: u64,
217}
218
/// Kind of user action stored in a [`UserActionRecord`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserActionType {
    InsertChar,
    InsertSelection,
    DeleteChar,
    DeleteSelection,
    CursorMovement,
}
227
228/// An event with associated metadata for reconstructing buffer state.
229#[derive(Clone)]
230pub struct StoredEvent {
231 pub event: Arc<zeta_prompt::Event>,
232 pub old_snapshot: TextBufferSnapshot,
233 pub edit_range: Range<Anchor>,
234}
235
236impl StoredEvent {
237 fn can_merge(
238 &self,
239 next_old_event: &&&StoredEvent,
240 new_snapshot: &TextBufferSnapshot,
241 last_edit_range: &Range<Anchor>,
242 ) -> bool {
243 // Events must be for the same buffer
244 if self.old_snapshot.remote_id() != next_old_event.old_snapshot.remote_id() {
245 return false;
246 }
247 if self.old_snapshot.remote_id() != new_snapshot.remote_id() {
248 return false;
249 }
250
251 let a_is_predicted = matches!(
252 self.event.as_ref(),
253 zeta_prompt::Event::BufferChange {
254 predicted: true,
255 ..
256 }
257 );
258 let b_is_predicted = matches!(
259 next_old_event.event.as_ref(),
260 zeta_prompt::Event::BufferChange {
261 predicted: true,
262 ..
263 }
264 );
265
266 // If events come from the same source (both predicted or both manual) then
267 // we would have coalesced them already.
268 if a_is_predicted == b_is_predicted {
269 return false;
270 }
271
272 let left_range = self.edit_range.to_point(new_snapshot);
273 let right_range = next_old_event.edit_range.to_point(new_snapshot);
274 let latest_range = last_edit_range.to_point(&new_snapshot);
275
276 // Events near to the latest edit are not merged if their sources differ.
277 if lines_between_ranges(&left_range, &latest_range)
278 .min(lines_between_ranges(&right_range, &latest_range))
279 <= CHANGE_GROUPING_LINE_SPAN
280 {
281 return false;
282 }
283
284 // Events that are distant from each other are not merged.
285 if lines_between_ranges(&left_range, &right_range) > CHANGE_GROUPING_LINE_SPAN {
286 return false;
287 }
288
289 true
290 }
291}
292
293fn lines_between_ranges(left: &Range<Point>, right: &Range<Point>) -> u32 {
294 if left.start > right.end {
295 return left.start.row - right.end.row;
296 }
297 if right.start > left.end {
298 return right.start.row - left.end.row;
299 }
300 0
301}
302
303struct ProjectState {
304 events: VecDeque<StoredEvent>,
305 last_event: Option<LastEvent>,
306 recent_paths: VecDeque<ProjectPath>,
307 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
308 current_prediction: Option<CurrentEditPrediction>,
309 next_pending_prediction_id: usize,
310 pending_predictions: ArrayVec<PendingPrediction, 2>,
311 debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
312 last_edit_prediction_refresh: Option<(EntityId, Instant)>,
313 last_jump_prediction_refresh: Option<(EntityId, Instant)>,
314 cancelled_predictions: HashSet<usize>,
315 context: Entity<RelatedExcerptStore>,
316 license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
317 user_actions: VecDeque<UserActionRecord>,
318 _subscriptions: [gpui::Subscription; 2],
319 copilot: Option<Entity<Copilot>>,
320}
321
322impl ProjectState {
323 fn record_user_action(&mut self, action: UserActionRecord) {
324 if self.user_actions.len() >= USER_ACTION_HISTORY_SIZE {
325 self.user_actions.pop_front();
326 }
327 self.user_actions.push_back(action);
328 }
329
330 pub fn events(&self, cx: &App) -> Vec<StoredEvent> {
331 self.events
332 .iter()
333 .cloned()
334 .chain(self.last_event.as_ref().iter().flat_map(|event| {
335 let (one, two) = event.split_by_pause();
336 let one = one.finalize(&self.license_detection_watchers, cx);
337 let two = two.and_then(|two| two.finalize(&self.license_detection_watchers, cx));
338 one.into_iter().chain(two)
339 }))
340 .collect()
341 }
342
343 fn cancel_pending_prediction(
344 &mut self,
345 pending_prediction: PendingPrediction,
346 cx: &mut Context<EditPredictionStore>,
347 ) {
348 self.cancelled_predictions.insert(pending_prediction.id);
349
350 if pending_prediction.drop_on_cancel {
351 drop(pending_prediction.task);
352 } else {
353 cx.spawn(async move |this, cx| {
354 let Some(prediction_id) = pending_prediction.task.await else {
355 return;
356 };
357
358 this.update(cx, |this, cx| {
359 this.reject_prediction(
360 prediction_id,
361 EditPredictionRejectReason::Canceled,
362 false,
363 cx,
364 );
365 })
366 .ok();
367 })
368 .detach()
369 }
370 }
371
372 fn active_buffer(
373 &self,
374 project: &Entity<Project>,
375 cx: &App,
376 ) -> Option<(Entity<Buffer>, Option<Anchor>)> {
377 let project = project.read(cx);
378 let active_path = project.path_for_entry(project.active_entry()?, cx)?;
379 let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
380 let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
381 Some((active_buffer, registered_buffer.last_position))
382 }
383}
384
385#[derive(Debug, Clone)]
386struct CurrentEditPrediction {
387 pub requested_by: PredictionRequestedBy,
388 pub prediction: EditPrediction,
389 pub was_shown: bool,
390 pub shown_with: Option<edit_prediction_types::SuggestionDisplayType>,
391}
392
393impl CurrentEditPrediction {
394 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
395 let Some(new_edits) = self
396 .prediction
397 .interpolate(&self.prediction.buffer.read(cx))
398 else {
399 return false;
400 };
401
402 if self.prediction.buffer != old_prediction.prediction.buffer {
403 return true;
404 }
405
406 let Some(old_edits) = old_prediction
407 .prediction
408 .interpolate(&old_prediction.prediction.buffer.read(cx))
409 else {
410 return true;
411 };
412
413 let requested_by_buffer_id = self.requested_by.buffer_id();
414
415 // This reduces the occurrence of UI thrash from replacing edits
416 //
417 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
418 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
419 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
420 && old_edits.len() == 1
421 && new_edits.len() == 1
422 {
423 let (old_range, old_text) = &old_edits[0];
424 let (new_range, new_text) = &new_edits[0];
425 new_range == old_range && new_text.starts_with(old_text.as_ref())
426 } else {
427 true
428 }
429 }
430}
431
432#[derive(Debug, Clone)]
433enum PredictionRequestedBy {
434 DiagnosticsUpdate,
435 Buffer(EntityId),
436}
437
438impl PredictionRequestedBy {
439 pub fn buffer_id(&self) -> Option<EntityId> {
440 match self {
441 PredictionRequestedBy::DiagnosticsUpdate => None,
442 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
443 }
444 }
445}
446
/// Number of lines used when searching for diagnostics.
/// NOTE(review): presumably a window around the cursor; not referenced in this part of the file — confirm.
const DIAGNOSTIC_LINES_RANGE: u32 = 20;

/// Scope of a diagnostics-driven prediction refresh; project diagnostics updates use `Global`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticSearchScope {
    Local,
    Global,
}
454
455#[derive(Debug)]
456struct PendingPrediction {
457 id: usize,
458 task: Task<Option<EditPredictionId>>,
459 /// If true, the task is dropped immediately on cancel (cancelling the HTTP request).
460 /// If false, the task is awaited to completion so rejection can be reported.
461 drop_on_cancel: bool,
462}
463
464/// A prediction from the perspective of a buffer.
465#[derive(Debug)]
466enum BufferEditPrediction<'a> {
467 Local { prediction: &'a EditPrediction },
468 Jump { prediction: &'a EditPrediction },
469}
470
471#[cfg(test)]
472impl std::ops::Deref for BufferEditPrediction<'_> {
473 type Target = EditPrediction;
474
475 fn deref(&self) -> &Self::Target {
476 match self {
477 BufferEditPrediction::Local { prediction } => prediction,
478 BufferEditPrediction::Jump { prediction } => prediction,
479 }
480 }
481}
482
483struct RegisteredBuffer {
484 file: Option<Arc<dyn File>>,
485 snapshot: TextBufferSnapshot,
486 last_position: Option<Anchor>,
487 _subscriptions: [gpui::Subscription; 2],
488}
489
490#[derive(Clone)]
491struct LastEvent {
492 old_snapshot: TextBufferSnapshot,
493 new_snapshot: TextBufferSnapshot,
494 old_file: Option<Arc<dyn File>>,
495 new_file: Option<Arc<dyn File>>,
496 edit_range: Option<Range<Anchor>>,
497 predicted: bool,
498 snapshot_after_last_editing_pause: Option<TextBufferSnapshot>,
499 last_edit_time: Option<Instant>,
500}
501
502impl LastEvent {
503 pub fn finalize(
504 &self,
505 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
506 cx: &App,
507 ) -> Option<StoredEvent> {
508 let path = buffer_path_with_id_fallback(self.new_file.as_ref(), &self.new_snapshot, cx);
509 let old_path = buffer_path_with_id_fallback(self.old_file.as_ref(), &self.old_snapshot, cx);
510
511 let in_open_source_repo =
512 [self.new_file.as_ref(), self.old_file.as_ref()]
513 .iter()
514 .all(|file| {
515 file.is_some_and(|file| {
516 license_detection_watchers
517 .get(&file.worktree_id(cx))
518 .is_some_and(|watcher| watcher.is_project_open_source())
519 })
520 });
521
522 let (diff, edit_range) =
523 compute_diff_between_snapshots(&self.old_snapshot, &self.new_snapshot)?;
524
525 if path == old_path && diff.is_empty() {
526 None
527 } else {
528 Some(StoredEvent {
529 event: Arc::new(zeta_prompt::Event::BufferChange {
530 old_path,
531 path,
532 diff,
533 in_open_source_repo,
534 predicted: self.predicted,
535 }),
536 edit_range: self.new_snapshot.anchor_before(edit_range.start)
537 ..self.new_snapshot.anchor_before(edit_range.end),
538 old_snapshot: self.old_snapshot.clone(),
539 })
540 }
541 }
542
543 pub fn split_by_pause(&self) -> (LastEvent, Option<LastEvent>) {
544 let Some(boundary_snapshot) = self.snapshot_after_last_editing_pause.as_ref() else {
545 return (self.clone(), None);
546 };
547
548 let before = LastEvent {
549 old_snapshot: self.old_snapshot.clone(),
550 new_snapshot: boundary_snapshot.clone(),
551 old_file: self.old_file.clone(),
552 new_file: self.new_file.clone(),
553 edit_range: None,
554 predicted: self.predicted,
555 snapshot_after_last_editing_pause: None,
556 last_edit_time: self.last_edit_time,
557 };
558
559 let after = LastEvent {
560 old_snapshot: boundary_snapshot.clone(),
561 new_snapshot: self.new_snapshot.clone(),
562 old_file: self.old_file.clone(),
563 new_file: self.new_file.clone(),
564 edit_range: None,
565 predicted: self.predicted,
566 snapshot_after_last_editing_pause: None,
567 last_edit_time: self.last_edit_time,
568 };
569
570 (before, Some(after))
571 }
572}
573
574pub(crate) fn compute_diff_between_snapshots(
575 old_snapshot: &TextBufferSnapshot,
576 new_snapshot: &TextBufferSnapshot,
577) -> Option<(String, Range<Point>)> {
578 let edits: Vec<Edit<usize>> = new_snapshot
579 .edits_since::<usize>(&old_snapshot.version)
580 .collect();
581
582 let (first_edit, last_edit) = edits.first().zip(edits.last())?;
583
584 let old_start_point = old_snapshot.offset_to_point(first_edit.old.start);
585 let old_end_point = old_snapshot.offset_to_point(last_edit.old.end);
586 let new_start_point = new_snapshot.offset_to_point(first_edit.new.start);
587 let new_end_point = new_snapshot.offset_to_point(last_edit.new.end);
588
589 const CONTEXT_LINES: u32 = 3;
590
591 let old_context_start_row = old_start_point.row.saturating_sub(CONTEXT_LINES);
592 let new_context_start_row = new_start_point.row.saturating_sub(CONTEXT_LINES);
593 let old_context_end_row =
594 (old_end_point.row + 1 + CONTEXT_LINES).min(old_snapshot.max_point().row);
595 let new_context_end_row =
596 (new_end_point.row + 1 + CONTEXT_LINES).min(new_snapshot.max_point().row);
597
598 let old_start_line_offset = old_snapshot.point_to_offset(Point::new(old_context_start_row, 0));
599 let new_start_line_offset = new_snapshot.point_to_offset(Point::new(new_context_start_row, 0));
600 let old_end_line_offset = old_snapshot
601 .point_to_offset(Point::new(old_context_end_row + 1, 0).min(old_snapshot.max_point()));
602 let new_end_line_offset = new_snapshot
603 .point_to_offset(Point::new(new_context_end_row + 1, 0).min(new_snapshot.max_point()));
604 let old_edit_range = old_start_line_offset..old_end_line_offset;
605 let new_edit_range = new_start_line_offset..new_end_line_offset;
606
607 let old_region_text: String = old_snapshot.text_for_range(old_edit_range).collect();
608 let new_region_text: String = new_snapshot.text_for_range(new_edit_range).collect();
609
610 let diff = language::unified_diff_with_offsets(
611 &old_region_text,
612 &new_region_text,
613 old_context_start_row,
614 new_context_start_row,
615 );
616
617 Some((diff, new_start_point..new_end_point))
618}
619
620fn buffer_path_with_id_fallback(
621 file: Option<&Arc<dyn File>>,
622 snapshot: &TextBufferSnapshot,
623 cx: &App,
624) -> Arc<Path> {
625 if let Some(file) = file {
626 file.full_path(cx).into()
627 } else {
628 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
629 }
630}
631
632impl EditPredictionStore {
633 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
634 cx.try_global::<EditPredictionStoreGlobal>()
635 .map(|global| global.0.clone())
636 }
637
638 pub fn global(
639 client: &Arc<Client>,
640 user_store: &Entity<UserStore>,
641 cx: &mut App,
642 ) -> Entity<Self> {
643 cx.try_global::<EditPredictionStoreGlobal>()
644 .map(|global| global.0.clone())
645 .unwrap_or_else(|| {
646 let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
647 cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
648 ep_store
649 })
650 }
651
652 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
653 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
654 let data_collection_choice = Self::load_data_collection_choice();
655
656 let llm_token = LlmApiToken::default();
657
658 let (reject_tx, reject_rx) = mpsc::unbounded();
659 cx.background_spawn({
660 let client = client.clone();
661 let llm_token = llm_token.clone();
662 let app_version = AppVersion::global(cx);
663 let background_executor = cx.background_executor().clone();
664 async move {
665 Self::handle_rejected_predictions(
666 reject_rx,
667 client,
668 llm_token,
669 app_version,
670 background_executor,
671 )
672 .await
673 }
674 })
675 .detach();
676
677 let this = Self {
678 projects: HashMap::default(),
679 client,
680 user_store,
681 llm_token,
682 _llm_token_subscription: cx.subscribe(
683 &refresh_llm_token_listener,
684 |this, _listener, _event, cx| {
685 let client = this.client.clone();
686 let llm_token = this.llm_token.clone();
687 cx.spawn(async move |_this, _cx| {
688 llm_token.refresh(&client).await?;
689 anyhow::Ok(())
690 })
691 .detach_and_log_err(cx);
692 },
693 ),
694 update_required: false,
695 edit_prediction_model: EditPredictionModel::Zeta2,
696 zeta2_raw_config: Self::zeta2_raw_config_from_env(),
697 sweep_ai: SweepAi::new(cx),
698 mercury: Mercury::new(cx),
699
700 data_collection_choice,
701 reject_predictions_tx: reject_tx,
702 rated_predictions: Default::default(),
703 shown_predictions: Default::default(),
704 };
705
706 this
707 }
708
709 fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
710 let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
711 let format = ZetaFormat::parse(&version_str).ok()?;
712 let model_id = env::var("ZED_ZETA_MODEL").ok();
713 Some(Zeta2RawConfig { model_id, format })
714 }
715
716 pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
717 self.edit_prediction_model = model;
718 }
719
720 pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
721 self.zeta2_raw_config = Some(config);
722 }
723
724 pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
725 self.zeta2_raw_config.as_ref()
726 }
727
728 pub fn icons(&self, cx: &App) -> edit_prediction_types::EditPredictionIconSet {
729 use ui::IconName;
730 match self.edit_prediction_model {
731 EditPredictionModel::Sweep => {
732 edit_prediction_types::EditPredictionIconSet::new(IconName::SweepAi)
733 .with_disabled(IconName::SweepAiDisabled)
734 .with_up(IconName::SweepAiUp)
735 .with_down(IconName::SweepAiDown)
736 .with_error(IconName::SweepAiError)
737 }
738 EditPredictionModel::Mercury => {
739 edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
740 }
741 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
742 edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
743 .with_disabled(IconName::ZedPredictDisabled)
744 .with_up(IconName::ZedPredictUp)
745 .with_down(IconName::ZedPredictDown)
746 .with_error(IconName::ZedPredictError)
747 }
748 EditPredictionModel::Fim { .. } => {
749 let settings = &all_language_settings(None, cx).edit_predictions;
750 match settings.provider {
751 EditPredictionProvider::Ollama => {
752 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOllama)
753 }
754 _ => {
755 edit_prediction_types::EditPredictionIconSet::new(IconName::AiOpenAiCompat)
756 }
757 }
758 }
759 }
760 }
761
762 pub fn has_sweep_api_token(&self, cx: &App) -> bool {
763 self.sweep_ai.api_token.read(cx).has_key()
764 }
765
766 pub fn has_mercury_api_token(&self, cx: &App) -> bool {
767 self.mercury.api_token.read(cx).has_key()
768 }
769
770 pub fn clear_history(&mut self) {
771 for project_state in self.projects.values_mut() {
772 project_state.events.clear();
773 project_state.last_event.take();
774 }
775 }
776
777 pub fn clear_history_for_project(&mut self, project: &Entity<Project>) {
778 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
779 project_state.events.clear();
780 project_state.last_event.take();
781 }
782 }
783
784 pub fn edit_history_for_project(
785 &self,
786 project: &Entity<Project>,
787 cx: &App,
788 ) -> Vec<StoredEvent> {
789 self.projects
790 .get(&project.entity_id())
791 .map(|project_state| project_state.events(cx))
792 .unwrap_or_default()
793 }
794
795 pub fn context_for_project<'a>(
796 &'a self,
797 project: &Entity<Project>,
798 cx: &'a mut App,
799 ) -> Vec<RelatedFile> {
800 self.projects
801 .get(&project.entity_id())
802 .map(|project_state| {
803 project_state.context.update(cx, |context, cx| {
804 context
805 .related_files_with_buffers(cx)
806 .map(|(mut related_file, buffer)| {
807 related_file.in_open_source_repo = buffer
808 .read(cx)
809 .file()
810 .map_or(false, |file| self.is_file_open_source(&project, file, cx));
811 related_file
812 })
813 .collect()
814 })
815 })
816 .unwrap_or_default()
817 }
818
819 pub fn copilot_for_project(&self, project: &Entity<Project>) -> Option<Entity<Copilot>> {
820 self.projects
821 .get(&project.entity_id())
822 .and_then(|project| project.copilot.clone())
823 }
824
825 pub fn start_copilot_for_project(
826 &mut self,
827 project: &Entity<Project>,
828 cx: &mut Context<Self>,
829 ) -> Option<Entity<Copilot>> {
830 if DisableAiSettings::get(None, cx).disable_ai {
831 return None;
832 }
833 let state = self.get_or_init_project(project, cx);
834
835 if state.copilot.is_some() {
836 return state.copilot.clone();
837 }
838 let _project = project.clone();
839 let project = project.read(cx);
840
841 let node = project.node_runtime().cloned();
842 if let Some(node) = node {
843 let next_id = project.languages().next_language_server_id();
844 let fs = project.fs().clone();
845
846 let copilot = cx.new(|cx| Copilot::new(Some(_project), next_id, fs, node, cx));
847 state.copilot = Some(copilot.clone());
848 Some(copilot)
849 } else {
850 None
851 }
852 }
853
854 pub fn context_for_project_with_buffers<'a>(
855 &'a self,
856 project: &Entity<Project>,
857 cx: &'a mut App,
858 ) -> Vec<(RelatedFile, Entity<Buffer>)> {
859 self.projects
860 .get(&project.entity_id())
861 .map(|project| {
862 project.context.update(cx, |context, cx| {
863 context.related_files_with_buffers(cx).collect()
864 })
865 })
866 .unwrap_or_default()
867 }
868
869 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
870 if matches!(
871 self.edit_prediction_model,
872 EditPredictionModel::Zeta2 | EditPredictionModel::Zeta1
873 ) {
874 self.user_store.read(cx).edit_prediction_usage()
875 } else {
876 None
877 }
878 }
879
880 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
881 self.get_or_init_project(project, cx);
882 }
883
884 pub fn register_buffer(
885 &mut self,
886 buffer: &Entity<Buffer>,
887 project: &Entity<Project>,
888 cx: &mut Context<Self>,
889 ) {
890 let project_state = self.get_or_init_project(project, cx);
891 Self::register_buffer_impl(project_state, buffer, project, cx);
892 }
893
894 fn get_or_init_project(
895 &mut self,
896 project: &Entity<Project>,
897 cx: &mut Context<Self>,
898 ) -> &mut ProjectState {
899 let entity_id = project.entity_id();
900 self.projects
901 .entry(entity_id)
902 .or_insert_with(|| ProjectState {
903 context: {
904 let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
905 cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
906 this.handle_excerpt_store_event(entity_id, event);
907 })
908 .detach();
909 related_excerpt_store
910 },
911 events: VecDeque::new(),
912 last_event: None,
913 recent_paths: VecDeque::new(),
914 debug_tx: None,
915 registered_buffers: HashMap::default(),
916 current_prediction: None,
917 cancelled_predictions: HashSet::default(),
918 pending_predictions: ArrayVec::new(),
919 next_pending_prediction_id: 0,
920 last_edit_prediction_refresh: None,
921 last_jump_prediction_refresh: None,
922 license_detection_watchers: HashMap::default(),
923 user_actions: VecDeque::with_capacity(USER_ACTION_HISTORY_SIZE),
924 _subscriptions: [
925 cx.subscribe(&project, Self::handle_project_event),
926 cx.observe_release(&project, move |this, _, cx| {
927 this.projects.remove(&entity_id);
928 cx.notify();
929 }),
930 ],
931 copilot: None,
932 })
933 }
934
935 pub fn remove_project(&mut self, project: &Entity<Project>) {
936 self.projects.remove(&project.entity_id());
937 }
938
939 fn handle_excerpt_store_event(
940 &mut self,
941 project_entity_id: EntityId,
942 event: &RelatedExcerptStoreEvent,
943 ) {
944 if let Some(project_state) = self.projects.get(&project_entity_id) {
945 if let Some(debug_tx) = project_state.debug_tx.clone() {
946 match event {
947 RelatedExcerptStoreEvent::StartedRefresh => {
948 debug_tx
949 .unbounded_send(DebugEvent::ContextRetrievalStarted(
950 ContextRetrievalStartedDebugEvent {
951 project_entity_id: project_entity_id,
952 timestamp: Instant::now(),
953 search_prompt: String::new(),
954 },
955 ))
956 .ok();
957 }
958 RelatedExcerptStoreEvent::FinishedRefresh {
959 cache_hit_count,
960 cache_miss_count,
961 mean_definition_latency,
962 max_definition_latency,
963 } => {
964 debug_tx
965 .unbounded_send(DebugEvent::ContextRetrievalFinished(
966 ContextRetrievalFinishedDebugEvent {
967 project_entity_id: project_entity_id,
968 timestamp: Instant::now(),
969 metadata: vec![
970 (
971 "Cache Hits",
972 format!(
973 "{}/{}",
974 cache_hit_count,
975 cache_hit_count + cache_miss_count
976 )
977 .into(),
978 ),
979 (
980 "Max LSP Time",
981 format!("{} ms", max_definition_latency.as_millis())
982 .into(),
983 ),
984 (
985 "Mean LSP Time",
986 format!("{} ms", mean_definition_latency.as_millis())
987 .into(),
988 ),
989 ],
990 },
991 ))
992 .ok();
993 }
994 }
995 }
996 }
997 }
998
999 pub fn debug_info(
1000 &mut self,
1001 project: &Entity<Project>,
1002 cx: &mut Context<Self>,
1003 ) -> mpsc::UnboundedReceiver<DebugEvent> {
1004 let project_state = self.get_or_init_project(project, cx);
1005 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
1006 project_state.debug_tx = Some(debug_watch_tx);
1007 debug_watch_rx
1008 }
1009
1010 fn handle_project_event(
1011 &mut self,
1012 project: Entity<Project>,
1013 event: &project::Event,
1014 cx: &mut Context<Self>,
1015 ) {
1016 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1017 return;
1018 }
1019 // TODO [zeta2] init with recent paths
1020 match event {
1021 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
1022 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1023 return;
1024 };
1025 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
1026 if let Some(path) = path {
1027 if let Some(ix) = project_state
1028 .recent_paths
1029 .iter()
1030 .position(|probe| probe == &path)
1031 {
1032 project_state.recent_paths.remove(ix);
1033 }
1034 project_state.recent_paths.push_front(path);
1035 }
1036 }
1037 project::Event::DiagnosticsUpdated { .. } => {
1038 if cx.has_flag::<Zeta2FeatureFlag>() {
1039 self.refresh_prediction_from_diagnostics(
1040 project,
1041 DiagnosticSearchScope::Global,
1042 cx,
1043 );
1044 }
1045 }
1046 _ => (),
1047 }
1048 }
1049
    /// Ensures `buffer` is tracked in `project_state.registered_buffers` and returns its entry.
    ///
    /// Side effects:
    /// - If the buffer has a file whose worktree has no license-detection watcher yet, one is
    ///   created. An `observe_release` on the worktree removes it again when the worktree
    ///   entity is released.
    /// - On first registration, a `RegisteredBuffer` is created with the buffer's current text
    ///   snapshot and file. It subscribes to the buffer's `Edited` events (forwarding them to
    ///   `report_changes_for_buffer` as non-predicted changes) and removes itself from
    ///   `registered_buffers` when the buffer entity is released.
    ///
    /// An already-registered buffer is returned unchanged; its snapshot is not refreshed here.
    fn register_buffer_impl<'a>(
        project_state: &'a mut ProjectState,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license-detection watcher per worktree that hosts a registered
        // buffer.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                project_state
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher once the worktree itself goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state
                                .license_detection_watchers
                                .remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match project_state.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let buf = buffer.read(cx);
                // Baseline snapshot that later edits are diffed against.
                let snapshot = buf.text_snapshot();
                let file = buf.file().cloned();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    file,
                    last_position: None,
                    _subscriptions: [
                        // Hold the project weakly so the subscription does not keep it alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, false, cx);
                                }
                            }
                        }),
                        // Forget the buffer when its entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(project_state) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            project_state.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
1115
    /// Records the edits made to `buffer` since its registered snapshot and folds them into
    /// the project's edit history.
    ///
    /// - Returns early if the buffer version is unchanged or the diff yields no edits.
    /// - Classifies the change as a `UserActionType` and records it, with the row and offset
    ///   of the end of the last edit.
    /// - Coalesces the change into the pending `last_event` when all of these hold:
    ///   - it continues the same buffer from the exact version the pending event ended at;
    ///   - it comes from the same source (`is_predicted`);
    ///   - it lies within `CHANGE_GROUPING_LINE_SPAN` lines of the previous edit range.
    /// - Otherwise, the pending event is finalized into `events`, and a new pending event is
    ///   started for this change.
    ///
    /// `is_predicted` is `true` when the change comes from accepting a prediction.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        is_predicted: bool,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let buf = buffer.read(cx);
        let new_file = buf.file().cloned();
        let new_snapshot = buf.text_snapshot();
        // Nothing changed since the last observation.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        // Advance the registered baseline and keep the old state for diffing/history.
        let old_file = mem::replace(&mut registered_buffer.file, new_file.clone());
        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        let mut num_edits = 0usize;
        let mut total_deleted = 0usize;
        let mut total_inserted = 0usize;
        let mut edit_range: Option<Range<Anchor>> = None;
        let mut last_offset: Option<usize> = None;

        // Accumulate size statistics and the anchor range spanning the first edit's start to
        // the last edit's end.
        for (edit, anchor_range) in
            new_snapshot.anchored_edits_since::<usize>(&old_snapshot.version)
        {
            num_edits += 1;
            total_deleted += edit.old.len();
            total_inserted += edit.new.len();
            edit_range = Some(match edit_range {
                None => anchor_range,
                Some(acc) => acc.start..anchor_range.end,
            });
            last_offset = Some(edit.new.end);
        }

        let Some(edit_range) = edit_range else {
            return;
        };

        // Heuristic classification: when the total inserted (or deleted) length equals the
        // number of edits it is treated as a per-character action, otherwise as a selection
        // action. Lengths are in `usize` offsets, presumably bytes — confirm for multi-byte text.
        let action_type = match (total_deleted, total_inserted, num_edits) {
            (0, ins, n) if ins == n => UserActionType::InsertChar,
            (0, _, _) => UserActionType::InsertSelection,
            (del, 0, n) if del == n => UserActionType::DeleteChar,
            (_, 0, _) => UserActionType::DeleteSelection,
            (_, ins, n) if ins == n => UserActionType::InsertChar,
            (_, _, _) => UserActionType::InsertSelection,
        };

        if let Some(offset) = last_offset {
            let point = new_snapshot.offset_to_point(offset);
            // Wall-clock timestamp; falls back to 0 if the clock is before the Unix epoch.
            let timestamp_epoch_ms = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(0);
            project_state.record_user_action(UserActionRecord {
                action_type,
                buffer_id: buffer.entity_id(),
                line_number: point.row,
                offset,
                timestamp_epoch_ms,
            });
        }

        let events = &mut project_state.events;

        let now = cx.background_executor().now();
        if let Some(last_event) = project_state.last_event.as_mut() {
            // The pending event ended exactly at the snapshot this change starts from.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_event.new_snapshot.remote_id()
                && old_snapshot.version == last_event.new_snapshot.version;

            let prediction_source_changed = is_predicted != last_event.predicted;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && !prediction_source_changed
                && last_event
                    .edit_range
                    .as_ref()
                    .is_some_and(|last_edit_range| {
                        lines_between_ranges(
                            &edit_range.to_point(&new_snapshot),
                            &last_edit_range.to_point(&new_snapshot),
                        ) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // After a pause of at least LAST_CHANGE_GROUPING_TIME, remember the snapshot
                // the pause started from before extending the event.
                let pause_elapsed = last_event
                    .last_edit_time
                    .map(|t| now.duration_since(t) >= LAST_CHANGE_GROUPING_TIME)
                    .unwrap_or(false);
                if pause_elapsed {
                    last_event.snapshot_after_last_editing_pause =
                        Some(last_event.new_snapshot.clone());
                }

                last_event.edit_range = Some(edit_range);
                last_event.new_snapshot = new_snapshot;
                last_event.last_edit_time = Some(now);
                return;
            }
        }

        // Finalize the previous pending event into history. The check below keeps at most
        // `EVENT_COUNT_MAX - 1` finalized events; presumably the pending `last_event` is meant
        // to fill the last slot — confirm.
        if let Some(event) = project_state.last_event.take() {
            if let Some(event) = event.finalize(&project_state.license_detection_watchers, cx) {
                if events.len() + 1 >= EVENT_COUNT_MAX {
                    events.pop_front();
                }
                events.push_back(event);
            }
        }

        merge_trailing_events_if_needed(events, &old_snapshot, &new_snapshot, &edit_range);

        // Start a new pending event for this change.
        project_state.last_event = Some(LastEvent {
            old_file,
            new_file,
            old_snapshot,
            new_snapshot,
            edit_range: Some(edit_range),
            predicted: is_predicted,
            snapshot_after_last_editing_pause: None,
            last_edit_time: Some(now),
        });
    }
1243
1244 fn prediction_at(
1245 &mut self,
1246 buffer: &Entity<Buffer>,
1247 position: Option<language::Anchor>,
1248 project: &Entity<Project>,
1249 cx: &App,
1250 ) -> Option<BufferEditPrediction<'_>> {
1251 let project_state = self.projects.get_mut(&project.entity_id())?;
1252 if let Some(position) = position
1253 && let Some(buffer) = project_state
1254 .registered_buffers
1255 .get_mut(&buffer.entity_id())
1256 {
1257 buffer.last_position = Some(position);
1258 }
1259
1260 let CurrentEditPrediction {
1261 requested_by,
1262 prediction,
1263 ..
1264 } = project_state.current_prediction.as_ref()?;
1265
1266 if prediction.targets_buffer(buffer.read(cx)) {
1267 Some(BufferEditPrediction::Local { prediction })
1268 } else {
1269 let show_jump = match requested_by {
1270 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
1271 requested_by_buffer_id == &buffer.entity_id()
1272 }
1273 PredictionRequestedBy::DiagnosticsUpdate => true,
1274 };
1275
1276 if show_jump {
1277 Some(BufferEditPrediction::Jump { prediction })
1278 } else {
1279 None
1280 }
1281 }
1282 }
1283
1284 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
1285 let Some(current_prediction) = self
1286 .projects
1287 .get_mut(&project.entity_id())
1288 .and_then(|project_state| project_state.current_prediction.take())
1289 else {
1290 return;
1291 };
1292
1293 self.report_changes_for_buffer(¤t_prediction.prediction.buffer, project, true, cx);
1294
1295 // can't hold &mut project_state ref across report_changes_for_buffer_call
1296 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1297 return;
1298 };
1299
1300 for pending_prediction in mem::take(&mut project_state.pending_predictions) {
1301 project_state.cancel_pending_prediction(pending_prediction, cx);
1302 }
1303
1304 match self.edit_prediction_model {
1305 EditPredictionModel::Sweep => {
1306 sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
1307 }
1308 EditPredictionModel::Mercury => {
1309 mercury::edit_prediction_accepted(
1310 current_prediction.prediction.id,
1311 self.client.http_client(),
1312 cx,
1313 );
1314 }
1315 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1316 let is_cloud = !matches!(
1317 all_language_settings(None, cx).edit_predictions.provider,
1318 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1319 );
1320 if is_cloud {
1321 zeta::edit_prediction_accepted(self, current_prediction, cx)
1322 }
1323 }
1324 EditPredictionModel::Fim { .. } => {}
1325 }
1326 }
1327
1328 async fn handle_rejected_predictions(
1329 rx: UnboundedReceiver<EditPredictionRejection>,
1330 client: Arc<Client>,
1331 llm_token: LlmApiToken,
1332 app_version: Version,
1333 background_executor: BackgroundExecutor,
1334 ) {
1335 let mut rx = std::pin::pin!(rx.peekable());
1336 let mut batched = Vec::new();
1337
1338 while let Some(rejection) = rx.next().await {
1339 batched.push(rejection);
1340
1341 if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
1342 select_biased! {
1343 next = rx.as_mut().peek().fuse() => {
1344 if next.is_some() {
1345 continue;
1346 }
1347 }
1348 () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
1349 }
1350 }
1351
1352 let url = client
1353 .http_client()
1354 .build_zed_llm_url("/predict_edits/reject", &[])
1355 .unwrap();
1356
1357 let flush_count = batched
1358 .len()
1359 // in case items have accumulated after failure
1360 .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
1361 let start = batched.len() - flush_count;
1362
1363 let body = RejectEditPredictionsBodyRef {
1364 rejections: &batched[start..],
1365 };
1366
1367 let result = Self::send_api_request::<()>(
1368 |builder| {
1369 let req = builder
1370 .uri(url.as_ref())
1371 .body(serde_json::to_string(&body)?.into());
1372 anyhow::Ok(req?)
1373 },
1374 client.clone(),
1375 llm_token.clone(),
1376 app_version.clone(),
1377 true,
1378 )
1379 .await;
1380
1381 if result.log_err().is_some() {
1382 batched.drain(start..);
1383 }
1384 }
1385 }
1386
1387 fn reject_current_prediction(
1388 &mut self,
1389 reason: EditPredictionRejectReason,
1390 project: &Entity<Project>,
1391 cx: &App,
1392 ) {
1393 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
1394 project_state.pending_predictions.clear();
1395 if let Some(prediction) = project_state.current_prediction.take() {
1396 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
1397 }
1398 };
1399 }
1400
1401 fn did_show_current_prediction(
1402 &mut self,
1403 project: &Entity<Project>,
1404 display_type: edit_prediction_types::SuggestionDisplayType,
1405 cx: &mut Context<Self>,
1406 ) {
1407 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1408 return;
1409 };
1410
1411 let Some(current_prediction) = project_state.current_prediction.as_mut() else {
1412 return;
1413 };
1414
1415 let is_jump = display_type == edit_prediction_types::SuggestionDisplayType::Jump;
1416 let previous_shown_with = current_prediction.shown_with;
1417
1418 if previous_shown_with.is_none() || !is_jump {
1419 current_prediction.shown_with = Some(display_type);
1420 }
1421
1422 let is_first_non_jump_show = !current_prediction.was_shown && !is_jump;
1423
1424 if is_first_non_jump_show {
1425 current_prediction.was_shown = true;
1426 }
1427
1428 let display_type_changed = previous_shown_with != Some(display_type);
1429
1430 if self.edit_prediction_model == EditPredictionModel::Sweep && display_type_changed {
1431 sweep_ai::edit_prediction_shown(
1432 &self.sweep_ai,
1433 self.client.clone(),
1434 ¤t_prediction.prediction,
1435 display_type,
1436 cx,
1437 );
1438 }
1439
1440 if is_first_non_jump_show {
1441 self.shown_predictions
1442 .push_front(current_prediction.prediction.clone());
1443 if self.shown_predictions.len() > 50 {
1444 let completion = self.shown_predictions.pop_back().unwrap();
1445 self.rated_predictions.remove(&completion.id);
1446 }
1447 }
1448 }
1449
1450 fn reject_prediction(
1451 &mut self,
1452 prediction_id: EditPredictionId,
1453 reason: EditPredictionRejectReason,
1454 was_shown: bool,
1455 cx: &App,
1456 ) {
1457 match self.edit_prediction_model {
1458 EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
1459 let is_cloud = !matches!(
1460 all_language_settings(None, cx).edit_predictions.provider,
1461 EditPredictionProvider::Ollama | EditPredictionProvider::OpenAiCompatibleApi
1462 );
1463 if is_cloud {
1464 self.reject_predictions_tx
1465 .unbounded_send(EditPredictionRejection {
1466 request_id: prediction_id.to_string(),
1467 reason,
1468 was_shown,
1469 })
1470 .log_err();
1471 }
1472 }
1473 EditPredictionModel::Mercury => {
1474 mercury::edit_prediction_rejected(
1475 prediction_id,
1476 was_shown,
1477 reason,
1478 self.client.http_client(),
1479 cx,
1480 );
1481 }
1482 EditPredictionModel::Sweep | EditPredictionModel::Fim { .. } => {}
1483 }
1484 }
1485
1486 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1487 self.projects
1488 .get(&project.entity_id())
1489 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1490 }
1491
1492 pub fn refresh_prediction_from_buffer(
1493 &mut self,
1494 project: Entity<Project>,
1495 buffer: Entity<Buffer>,
1496 position: language::Anchor,
1497 cx: &mut Context<Self>,
1498 ) {
1499 self.queue_prediction_refresh(
1500 project.clone(),
1501 PredictEditsRequestTrigger::Other,
1502 buffer.entity_id(),
1503 cx,
1504 move |this, cx| {
1505 let Some(request_task) = this
1506 .update(cx, |this, cx| {
1507 this.request_prediction(
1508 &project,
1509 &buffer,
1510 position,
1511 PredictEditsRequestTrigger::Other,
1512 cx,
1513 )
1514 })
1515 .log_err()
1516 else {
1517 return Task::ready(anyhow::Ok(None));
1518 };
1519
1520 cx.spawn(async move |_cx| {
1521 request_task.await.map(|prediction_result| {
1522 prediction_result.map(|prediction_result| {
1523 (
1524 prediction_result,
1525 PredictionRequestedBy::Buffer(buffer.entity_id()),
1526 )
1527 })
1528 })
1529 })
1530 },
1531 )
1532 }
1533
1534 pub fn refresh_prediction_from_diagnostics(
1535 &mut self,
1536 project: Entity<Project>,
1537 scope: DiagnosticSearchScope,
1538 cx: &mut Context<Self>,
1539 ) {
1540 if !is_ep_store_provider(all_language_settings(None, cx).edit_predictions.provider) {
1541 return;
1542 }
1543
1544 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
1545 return;
1546 };
1547
1548 // Prefer predictions from buffer
1549 if project_state.current_prediction.is_some() {
1550 return;
1551 }
1552
1553 self.queue_prediction_refresh(
1554 project.clone(),
1555 PredictEditsRequestTrigger::Diagnostics,
1556 project.entity_id(),
1557 cx,
1558 move |this, cx| {
1559 let Some((active_buffer, snapshot, cursor_point)) = this
1560 .read_with(cx, |this, cx| {
1561 let project_state = this.projects.get(&project.entity_id())?;
1562 let (buffer, position) = project_state.active_buffer(&project, cx)?;
1563 let snapshot = buffer.read(cx).snapshot();
1564
1565 if !Self::predictions_enabled_at(&snapshot, position, cx) {
1566 return None;
1567 }
1568
1569 let cursor_point = position
1570 .map(|pos| pos.to_point(&snapshot))
1571 .unwrap_or_default();
1572
1573 Some((buffer, snapshot, cursor_point))
1574 })
1575 .log_err()
1576 .flatten()
1577 else {
1578 return Task::ready(anyhow::Ok(None));
1579 };
1580
1581 cx.spawn(async move |cx| {
1582 let diagnostic_search_range = match scope {
1583 DiagnosticSearchScope::Local => {
1584 let diagnostic_search_start =
1585 cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1586 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1587 Point::new(diagnostic_search_start, 0)
1588 ..Point::new(diagnostic_search_end, 0)
1589 }
1590 DiagnosticSearchScope::Global => Default::default(),
1591 };
1592
1593 let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1594 active_buffer,
1595 &snapshot,
1596 diagnostic_search_range,
1597 cursor_point,
1598 &project,
1599 cx,
1600 )
1601 .await?
1602 else {
1603 return anyhow::Ok(None);
1604 };
1605
1606 let Some(prediction_result) = this
1607 .update(cx, |this, cx| {
1608 this.request_prediction(
1609 &project,
1610 &jump_buffer,
1611 jump_position,
1612 PredictEditsRequestTrigger::Diagnostics,
1613 cx,
1614 )
1615 })?
1616 .await?
1617 else {
1618 return anyhow::Ok(None);
1619 };
1620
1621 this.update(cx, |this, cx| {
1622 Some((
1623 if this
1624 .get_or_init_project(&project, cx)
1625 .current_prediction
1626 .is_none()
1627 {
1628 prediction_result
1629 } else {
1630 EditPredictionResult {
1631 id: prediction_result.id,
1632 prediction: Err(EditPredictionRejectReason::CurrentPreferred),
1633 }
1634 },
1635 PredictionRequestedBy::DiagnosticsUpdate,
1636 ))
1637 })
1638 })
1639 },
1640 );
1641 }
1642
1643 fn predictions_enabled_at(
1644 snapshot: &BufferSnapshot,
1645 position: Option<language::Anchor>,
1646 cx: &App,
1647 ) -> bool {
1648 let file = snapshot.file();
1649 let all_settings = all_language_settings(file, cx);
1650 if !all_settings.show_edit_predictions(snapshot.language(), cx)
1651 || file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
1652 {
1653 return false;
1654 }
1655
1656 if let Some(last_position) = position {
1657 let settings = snapshot.settings_at(last_position, cx);
1658
1659 if !settings.edit_predictions_disabled_in.is_empty()
1660 && let Some(scope) = snapshot.language_scope_at(last_position)
1661 && let Some(scope_name) = scope.override_name()
1662 && settings
1663 .edit_predictions_disabled_in
1664 .iter()
1665 .any(|s| s == scope_name)
1666 {
1667 return false;
1668 }
1669 }
1670
1671 true
1672 }
1673
    /// Minimum spacing between consecutive prediction refreshes that target the same throttle
    /// entity, measured from the last refresh that was not cancelled. Used by
    /// `queue_prediction_refresh`, which tracks diagnostics-triggered and other refreshes
    /// separately.
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
1675}
1676
1677fn is_ep_store_provider(provider: EditPredictionProvider) -> bool {
1678 match provider {
1679 EditPredictionProvider::Zed
1680 | EditPredictionProvider::Sweep
1681 | EditPredictionProvider::Mercury
1682 | EditPredictionProvider::Ollama
1683 | EditPredictionProvider::OpenAiCompatibleApi
1684 | EditPredictionProvider::Experimental(_) => true,
1685 EditPredictionProvider::None
1686 | EditPredictionProvider::Copilot
1687 | EditPredictionProvider::Supermaven
1688 | EditPredictionProvider::Codestral => false,
1689 }
1690}
1691
1692impl EditPredictionStore {
1693 fn queue_prediction_refresh(
1694 &mut self,
1695 project: Entity<Project>,
1696 request_trigger: PredictEditsRequestTrigger,
1697 throttle_entity: EntityId,
1698 cx: &mut Context<Self>,
1699 do_refresh: impl FnOnce(
1700 WeakEntity<Self>,
1701 &mut AsyncApp,
1702 )
1703 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1704 + 'static,
1705 ) {
1706 fn select_throttle(
1707 project_state: &mut ProjectState,
1708 request_trigger: PredictEditsRequestTrigger,
1709 ) -> &mut Option<(EntityId, Instant)> {
1710 match request_trigger {
1711 PredictEditsRequestTrigger::Diagnostics => {
1712 &mut project_state.last_jump_prediction_refresh
1713 }
1714 _ => &mut project_state.last_edit_prediction_refresh,
1715 }
1716 }
1717
1718 let (needs_acceptance_tracking, max_pending_predictions) =
1719 match all_language_settings(None, cx).edit_predictions.provider {
1720 EditPredictionProvider::Zed
1721 | EditPredictionProvider::Sweep
1722 | EditPredictionProvider::Mercury
1723 | EditPredictionProvider::Experimental(_) => (true, 2),
1724 EditPredictionProvider::Ollama => (false, 1),
1725 EditPredictionProvider::OpenAiCompatibleApi => (false, 2),
1726 EditPredictionProvider::None
1727 | EditPredictionProvider::Copilot
1728 | EditPredictionProvider::Supermaven
1729 | EditPredictionProvider::Codestral => {
1730 log::error!("queue_prediction_refresh called with non-store provider");
1731 return;
1732 }
1733 };
1734
1735 let drop_on_cancel = !needs_acceptance_tracking;
1736 let throttle_timeout = Self::THROTTLE_TIMEOUT;
1737 let project_state = self.get_or_init_project(&project, cx);
1738 let pending_prediction_id = project_state.next_pending_prediction_id;
1739 project_state.next_pending_prediction_id += 1;
1740 let last_request = *select_throttle(project_state, request_trigger);
1741
1742 let task = cx.spawn(async move |this, cx| {
1743 if let Some(timeout) = last_request.and_then(|(last_entity, last_timestamp)| {
1744 if throttle_entity != last_entity {
1745 return None;
1746 }
1747 (last_timestamp + throttle_timeout).checked_duration_since(Instant::now())
1748 }) {
1749 cx.background_executor().timer(timeout).await;
1750 }
1751
1752 // If this task was cancelled before the throttle timeout expired,
1753 // do not perform a request.
1754 let mut is_cancelled = true;
1755 this.update(cx, |this, cx| {
1756 let project_state = this.get_or_init_project(&project, cx);
1757 let was_cancelled = project_state
1758 .cancelled_predictions
1759 .remove(&pending_prediction_id);
1760 if !was_cancelled {
1761 let new_refresh = (throttle_entity, Instant::now());
1762 *select_throttle(project_state, request_trigger) = Some(new_refresh);
1763 is_cancelled = false;
1764 }
1765 })
1766 .ok();
1767 if is_cancelled {
1768 return None;
1769 }
1770
1771 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1772 let new_prediction_id = new_prediction_result
1773 .as_ref()
1774 .map(|(prediction, _)| prediction.id.clone());
1775
1776 // When a prediction completes, remove it from the pending list, and cancel
1777 // any pending predictions that were enqueued before it.
1778 this.update(cx, |this, cx| {
1779 let project_state = this.get_or_init_project(&project, cx);
1780
1781 let is_cancelled = project_state
1782 .cancelled_predictions
1783 .remove(&pending_prediction_id);
1784
1785 let new_current_prediction = if !is_cancelled
1786 && let Some((prediction_result, requested_by)) = new_prediction_result
1787 {
1788 match prediction_result.prediction {
1789 Ok(prediction) => {
1790 let new_prediction = CurrentEditPrediction {
1791 requested_by,
1792 prediction,
1793 was_shown: false,
1794 shown_with: None,
1795 };
1796
1797 if let Some(current_prediction) =
1798 project_state.current_prediction.as_ref()
1799 {
1800 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1801 {
1802 this.reject_current_prediction(
1803 EditPredictionRejectReason::Replaced,
1804 &project,
1805 cx,
1806 );
1807
1808 Some(new_prediction)
1809 } else {
1810 this.reject_prediction(
1811 new_prediction.prediction.id,
1812 EditPredictionRejectReason::CurrentPreferred,
1813 false,
1814 cx,
1815 );
1816 None
1817 }
1818 } else {
1819 Some(new_prediction)
1820 }
1821 }
1822 Err(reject_reason) => {
1823 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1824 None
1825 }
1826 }
1827 } else {
1828 None
1829 };
1830
1831 let project_state = this.get_or_init_project(&project, cx);
1832
1833 if let Some(new_prediction) = new_current_prediction {
1834 project_state.current_prediction = Some(new_prediction);
1835 }
1836
1837 let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
1838 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1839 if pending_prediction.id == pending_prediction_id {
1840 pending_predictions.remove(ix);
1841 for pending_prediction in pending_predictions.drain(0..ix) {
1842 project_state.cancel_pending_prediction(pending_prediction, cx)
1843 }
1844 break;
1845 }
1846 }
1847 this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
1848 cx.notify();
1849 })
1850 .ok();
1851
1852 new_prediction_id
1853 });
1854
1855 if project_state.pending_predictions.len() < max_pending_predictions {
1856 project_state.pending_predictions.push(PendingPrediction {
1857 id: pending_prediction_id,
1858 task,
1859 drop_on_cancel,
1860 });
1861 } else {
1862 let pending_prediction = project_state.pending_predictions.pop().unwrap();
1863 project_state.pending_predictions.push(PendingPrediction {
1864 id: pending_prediction_id,
1865 task,
1866 drop_on_cancel,
1867 });
1868 project_state.cancel_pending_prediction(pending_prediction, cx);
1869 }
1870 }
1871
1872 pub fn request_prediction(
1873 &mut self,
1874 project: &Entity<Project>,
1875 active_buffer: &Entity<Buffer>,
1876 position: language::Anchor,
1877 trigger: PredictEditsRequestTrigger,
1878 cx: &mut Context<Self>,
1879 ) -> Task<Result<Option<EditPredictionResult>>> {
1880 self.request_prediction_internal(
1881 project.clone(),
1882 active_buffer.clone(),
1883 position,
1884 trigger,
1885 cx.has_flag::<Zeta2FeatureFlag>(),
1886 cx,
1887 )
1888 }
1889
1890 fn request_prediction_internal(
1891 &mut self,
1892 project: Entity<Project>,
1893 active_buffer: Entity<Buffer>,
1894 position: language::Anchor,
1895 trigger: PredictEditsRequestTrigger,
1896 allow_jump: bool,
1897 cx: &mut Context<Self>,
1898 ) -> Task<Result<Option<EditPredictionResult>>> {
1899 self.get_or_init_project(&project, cx);
1900 let project_state = self.projects.get(&project.entity_id()).unwrap();
1901 let stored_events = project_state.events(cx);
1902 let has_events = !stored_events.is_empty();
1903 let events: Vec<Arc<zeta_prompt::Event>> =
1904 stored_events.into_iter().map(|e| e.event).collect();
1905 let debug_tx = project_state.debug_tx.clone();
1906
1907 let snapshot = active_buffer.read(cx).snapshot();
1908 let cursor_point = position.to_point(&snapshot);
1909 let current_offset = position.to_offset(&snapshot);
1910
1911 let mut user_actions: Vec<UserActionRecord> =
1912 project_state.user_actions.iter().cloned().collect();
1913
1914 if let Some(last_action) = user_actions.last() {
1915 if last_action.buffer_id == active_buffer.entity_id()
1916 && current_offset != last_action.offset
1917 {
1918 let timestamp_epoch_ms = SystemTime::now()
1919 .duration_since(UNIX_EPOCH)
1920 .map(|d| d.as_millis() as u64)
1921 .unwrap_or(0);
1922 user_actions.push(UserActionRecord {
1923 action_type: UserActionType::CursorMovement,
1924 buffer_id: active_buffer.entity_id(),
1925 line_number: cursor_point.row,
1926 offset: current_offset,
1927 timestamp_epoch_ms,
1928 });
1929 }
1930 }
1931 let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1932 let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1933 let diagnostic_search_range =
1934 Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1935
1936 let related_files = self.context_for_project(&project, cx);
1937
1938 let inputs = EditPredictionModelInput {
1939 project: project.clone(),
1940 buffer: active_buffer,
1941 snapshot: snapshot,
1942 position,
1943 events,
1944 related_files,
1945 recent_paths: project_state.recent_paths.clone(),
1946 trigger,
1947 diagnostic_search_range: diagnostic_search_range,
1948 debug_tx,
1949 user_actions,
1950 };
1951
1952 let task = match self.edit_prediction_model {
1953 EditPredictionModel::Zeta1 => zeta::request_prediction_with_zeta(
1954 self,
1955 inputs,
1956 Some(zeta_prompt::EditPredictionModelKind::Zeta1),
1957 cx,
1958 ),
1959 EditPredictionModel::Zeta2 => zeta::request_prediction_with_zeta(
1960 self,
1961 inputs,
1962 Some(zeta_prompt::EditPredictionModelKind::Zeta2),
1963 cx,
1964 ),
1965 EditPredictionModel::Fim { format } => fim::request_prediction(inputs, format, cx),
1966 EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
1967 EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
1968 };
1969
1970 cx.spawn(async move |this, cx| {
1971 let prediction = task.await?;
1972
1973 if prediction.is_none() && allow_jump && has_events {
1974 this.update(cx, |this, cx| {
1975 this.refresh_prediction_from_diagnostics(
1976 project,
1977 DiagnosticSearchScope::Local,
1978 cx,
1979 );
1980 })?;
1981 return anyhow::Ok(None);
1982 }
1983
1984 Ok(prediction)
1985 })
1986 }
1987
1988 pub(crate) async fn next_diagnostic_location(
1989 active_buffer: Entity<Buffer>,
1990 active_buffer_snapshot: &BufferSnapshot,
1991 active_buffer_diagnostic_search_range: Range<Point>,
1992 active_buffer_cursor_point: Point,
1993 project: &Entity<Project>,
1994 cx: &mut AsyncApp,
1995 ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
1996 let collaborator_cursor_rows: Vec<u32> = active_buffer_snapshot
1997 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
1998 .flat_map(|(_, _, _, selections)| {
1999 selections.map(|s| s.head().to_point(active_buffer_snapshot).row)
2000 })
2001 .collect();
2002
2003 let mut jump_location = active_buffer_snapshot
2004 .diagnostic_groups(None)
2005 .into_iter()
2006 .filter_map(|(_, group)| {
2007 let range = &group.entries[group.primary_ix]
2008 .range
2009 .to_point(&active_buffer_snapshot);
2010 if range.overlaps(&active_buffer_diagnostic_search_range) {
2011 return None;
2012 }
2013 let near_collaborator = collaborator_cursor_rows.iter().any(|&collab_row| {
2014 range.start.row.abs_diff(collab_row) <= DIAGNOSTIC_LINES_RANGE
2015 });
2016 let near_local = active_buffer_cursor_point.row.abs_diff(range.start.row)
2017 <= DIAGNOSTIC_LINES_RANGE;
2018 if near_collaborator && !near_local {
2019 return None;
2020 }
2021 Some(range.start)
2022 })
2023 .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
2024 .map(|position| {
2025 (
2026 active_buffer.clone(),
2027 active_buffer_snapshot.anchor_before(position),
2028 )
2029 });
2030
2031 if jump_location.is_none() {
2032 let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
2033 let file = buffer.file()?;
2034
2035 Some(ProjectPath {
2036 worktree_id: file.worktree_id(cx),
2037 path: file.path().clone(),
2038 })
2039 });
2040
2041 let mut candidates: Vec<(ProjectPath, usize)> = project.read_with(cx, |project, cx| {
2042 project
2043 .diagnostic_summaries(false, cx)
2044 .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
2045 .map(|(path, _, _)| {
2046 let shared_prefix = path
2047 .path
2048 .components()
2049 .zip(
2050 active_buffer_path
2051 .as_ref()
2052 .map(|p| p.path.components())
2053 .unwrap_or_default(),
2054 )
2055 .take_while(|(a, b)| a == b)
2056 .count();
2057 (path, shared_prefix)
2058 })
2059 .collect()
2060 });
2061
2062 candidates.sort_by(|a, b| b.1.cmp(&a.1));
2063
2064 for (path, _) in candidates {
2065 let candidate_buffer = project
2066 .update(cx, |project, cx| project.open_buffer(path, cx))
2067 .await?;
2068
2069 let (has_collaborators, diagnostic_position) =
2070 candidate_buffer.read_with(cx, |buffer, _cx| {
2071 let snapshot = buffer.snapshot();
2072 let has_collaborators = snapshot
2073 .selections_in_range(Anchor::MIN..Anchor::MAX, false)
2074 .next()
2075 .is_some();
2076 let position = buffer
2077 .buffer_diagnostics(None)
2078 .into_iter()
2079 .min_by_key(|entry| entry.diagnostic.severity)
2080 .map(|entry| entry.range.start);
2081 (has_collaborators, position)
2082 });
2083
2084 if has_collaborators {
2085 continue;
2086 }
2087
2088 if let Some(position) = diagnostic_position {
2089 jump_location = Some((candidate_buffer, position));
2090 break;
2091 }
2092 }
2093 }
2094
2095 anyhow::Ok(jump_location)
2096 }
2097
2098 async fn send_raw_llm_request(
2099 request: RawCompletionRequest,
2100 client: Arc<Client>,
2101 custom_url: Option<Arc<Url>>,
2102 llm_token: LlmApiToken,
2103 app_version: Version,
2104 ) -> Result<(RawCompletionResponse, Option<EditPredictionUsage>)> {
2105 let url = if let Some(custom_url) = custom_url {
2106 custom_url.as_ref().clone()
2107 } else {
2108 client
2109 .http_client()
2110 .build_zed_llm_url("/predict_edits/raw", &[])?
2111 };
2112
2113 Self::send_api_request(
2114 |builder| {
2115 let req = builder
2116 .uri(url.as_ref())
2117 .body(serde_json::to_string(&request)?.into());
2118 Ok(req?)
2119 },
2120 client,
2121 llm_token,
2122 app_version,
2123 true,
2124 )
2125 .await
2126 }
2127
2128 pub(crate) async fn send_v3_request(
2129 input: ZetaPromptInput,
2130 client: Arc<Client>,
2131 llm_token: LlmApiToken,
2132 app_version: Version,
2133 trigger: PredictEditsRequestTrigger,
2134 ) -> Result<(PredictEditsV3Response, Option<EditPredictionUsage>)> {
2135 let url = client
2136 .http_client()
2137 .build_zed_llm_url("/predict_edits/v3", &[])?;
2138
2139 let request = PredictEditsV3Request { input, trigger };
2140
2141 let json_bytes = serde_json::to_vec(&request)?;
2142 let compressed = zstd::encode_all(&json_bytes[..], 3)?;
2143
2144 Self::send_api_request(
2145 |builder| {
2146 let req = builder
2147 .uri(url.as_ref())
2148 .header("Content-Encoding", "zstd")
2149 .body(compressed.clone().into());
2150 Ok(req?)
2151 },
2152 client,
2153 llm_token,
2154 app_version,
2155 true,
2156 )
2157 .await
2158 }
2159
2160 fn handle_api_response<T>(
2161 this: &WeakEntity<Self>,
2162 response: Result<(T, Option<EditPredictionUsage>)>,
2163 cx: &mut gpui::AsyncApp,
2164 ) -> Result<T> {
2165 match response {
2166 Ok((data, usage)) => {
2167 if let Some(usage) = usage {
2168 this.update(cx, |this, cx| {
2169 this.user_store.update(cx, |user_store, cx| {
2170 user_store.update_edit_prediction_usage(usage, cx);
2171 });
2172 })
2173 .ok();
2174 }
2175 Ok(data)
2176 }
2177 Err(err) => {
2178 if err.is::<ZedUpdateRequiredError>() {
2179 cx.update(|cx| {
2180 this.update(cx, |this, _cx| {
2181 this.update_required = true;
2182 })
2183 .ok();
2184
2185 let error_message: SharedString = err.to_string().into();
2186 show_app_notification(
2187 NotificationId::unique::<ZedUpdateRequiredError>(),
2188 cx,
2189 move |cx| {
2190 cx.new(|cx| {
2191 ErrorMessagePrompt::new(error_message.clone(), cx)
2192 .with_link_button("Update Zed", "https://zed.dev/releases")
2193 })
2194 },
2195 );
2196 });
2197 }
2198 Err(err)
2199 }
2200 }
2201 }
2202
2203 async fn send_api_request<Res>(
2204 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
2205 client: Arc<Client>,
2206 llm_token: LlmApiToken,
2207 app_version: Version,
2208 require_auth: bool,
2209 ) -> Result<(Res, Option<EditPredictionUsage>)>
2210 where
2211 Res: DeserializeOwned,
2212 {
2213 let http_client = client.http_client();
2214
2215 let mut token = if let Ok(custom_token) = std::env::var("ZED_PREDICT_EDITS_TOKEN") {
2216 Some(custom_token)
2217 } else if require_auth {
2218 Some(llm_token.acquire(&client).await?)
2219 } else {
2220 llm_token.acquire(&client).await.ok()
2221 };
2222 let mut did_retry = false;
2223
2224 loop {
2225 let request_builder = http_client::Request::builder().method(Method::POST);
2226
2227 let mut request_builder = request_builder
2228 .header("Content-Type", "application/json")
2229 .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
2230
2231 // Only add Authorization header if we have a token
2232 if let Some(ref token_value) = token {
2233 request_builder =
2234 request_builder.header("Authorization", format!("Bearer {}", token_value));
2235 }
2236
2237 let request = build(request_builder)?;
2238
2239 let mut response = http_client.send(request).await?;
2240
2241 if let Some(minimum_required_version) = response
2242 .headers()
2243 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
2244 .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
2245 {
2246 anyhow::ensure!(
2247 app_version >= minimum_required_version,
2248 ZedUpdateRequiredError {
2249 minimum_version: minimum_required_version
2250 }
2251 );
2252 }
2253
2254 if response.status().is_success() {
2255 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
2256
2257 let mut body = Vec::new();
2258 response.body_mut().read_to_end(&mut body).await?;
2259 return Ok((serde_json::from_slice(&body)?, usage));
2260 } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
2261 did_retry = true;
2262 token = Some(llm_token.refresh(&client).await?);
2263 } else {
2264 let mut body = String::new();
2265 response.body_mut().read_to_string(&mut body).await?;
2266 anyhow::bail!(
2267 "Request failed with status: {:?}\nBody: {}",
2268 response.status(),
2269 body
2270 );
2271 }
2272 }
2273 }
2274
2275 pub fn refresh_context(
2276 &mut self,
2277 project: &Entity<Project>,
2278 buffer: &Entity<language::Buffer>,
2279 cursor_position: language::Anchor,
2280 cx: &mut Context<Self>,
2281 ) {
2282 self.get_or_init_project(project, cx)
2283 .context
2284 .update(cx, |store, cx| {
2285 store.refresh(buffer.clone(), cursor_position, cx);
2286 });
2287 }
2288
2289 #[cfg(feature = "cli-support")]
2290 pub fn set_context_for_buffer(
2291 &mut self,
2292 project: &Entity<Project>,
2293 related_files: Vec<RelatedFile>,
2294 cx: &mut Context<Self>,
2295 ) {
2296 self.get_or_init_project(project, cx)
2297 .context
2298 .update(cx, |store, cx| {
2299 store.set_related_files(related_files, cx);
2300 });
2301 }
2302
2303 #[cfg(feature = "cli-support")]
2304 pub fn set_recent_paths_for_project(
2305 &mut self,
2306 project: &Entity<Project>,
2307 paths: impl IntoIterator<Item = project::ProjectPath>,
2308 cx: &mut Context<Self>,
2309 ) {
2310 let project_state = self.get_or_init_project(project, cx);
2311 project_state.recent_paths = paths.into_iter().collect();
2312 }
2313
2314 fn is_file_open_source(
2315 &self,
2316 project: &Entity<Project>,
2317 file: &Arc<dyn File>,
2318 cx: &App,
2319 ) -> bool {
2320 if !file.is_local() || file.is_private() {
2321 return false;
2322 }
2323 let Some(project_state) = self.projects.get(&project.entity_id()) else {
2324 return false;
2325 };
2326 project_state
2327 .license_detection_watchers
2328 .get(&file.worktree_id(cx))
2329 .as_ref()
2330 .is_some_and(|watcher| watcher.is_project_open_source())
2331 }
2332
2333 pub(crate) fn is_data_collection_enabled(&self, cx: &App) -> bool {
2334 self.data_collection_choice.is_enabled(cx)
2335 }
2336
2337 fn load_data_collection_choice() -> DataCollectionChoice {
2338 let choice = KEY_VALUE_STORE
2339 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2340 .log_err()
2341 .flatten();
2342
2343 match choice.as_deref() {
2344 Some("true") => DataCollectionChoice::Enabled,
2345 Some("false") => DataCollectionChoice::Disabled,
2346 Some(_) => {
2347 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2348 DataCollectionChoice::NotAnswered
2349 }
2350 None => DataCollectionChoice::NotAnswered,
2351 }
2352 }
2353
2354 fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
2355 self.data_collection_choice = self.data_collection_choice.toggle();
2356 let new_choice = self.data_collection_choice;
2357 let is_enabled = new_choice.is_enabled(cx);
2358 db::write_and_log(cx, move || {
2359 KEY_VALUE_STORE.write_kvp(
2360 ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
2361 is_enabled.to_string(),
2362 )
2363 });
2364 }
2365
2366 pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
2367 self.shown_predictions.iter()
2368 }
2369
2370 pub fn shown_completions_len(&self) -> usize {
2371 self.shown_predictions.len()
2372 }
2373
2374 pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
2375 self.rated_predictions.contains(id)
2376 }
2377
2378 pub fn rate_prediction(
2379 &mut self,
2380 prediction: &EditPrediction,
2381 rating: EditPredictionRating,
2382 feedback: String,
2383 cx: &mut Context<Self>,
2384 ) {
2385 let organization = self.user_store.read(cx).current_organization();
2386
2387 self.rated_predictions.insert(prediction.id.clone());
2388
2389 cx.background_spawn({
2390 let client = self.client.clone();
2391 let prediction_id = prediction.id.to_string();
2392 let inputs = serde_json::to_value(&prediction.inputs);
2393 let output = prediction
2394 .edit_preview
2395 .as_unified_diff(prediction.snapshot.file(), &prediction.edits);
2396 async move {
2397 client
2398 .cloud_client()
2399 .submit_edit_prediction_feedback(SubmitEditPredictionFeedbackBody {
2400 organization_id: organization.map(|organization| organization.id.clone()),
2401 request_id: prediction_id,
2402 rating: match rating {
2403 EditPredictionRating::Positive => "positive".to_string(),
2404 EditPredictionRating::Negative => "negative".to_string(),
2405 },
2406 inputs: inputs?,
2407 output,
2408 feedback,
2409 })
2410 .await?;
2411
2412 anyhow::Ok(())
2413 }
2414 })
2415 .detach_and_log_err(cx);
2416
2417 cx.notify();
2418 }
2419}
2420
/// Collapses the trailing run of mergeable events in `events` into a single
/// `BufferChange` event.
///
/// Walking backwards from the newest event, each older event is counted while it
/// `can_merge` with the newer event counted just before it. The check uses
/// `latest_snapshot` and `latest_edit_range`. When at least two events qualify,
/// they are replaced by one event whose diff runs from the oldest counted event's
/// `old_snapshot` to `end_snapshot`. If `compute_diff_between_snapshots` yields no
/// diff, `events` is left unchanged.
fn merge_trailing_events_if_needed(
    events: &mut VecDeque<StoredEvent>,
    end_snapshot: &TextBufferSnapshot,
    latest_snapshot: &TextBufferSnapshot,
    latest_edit_range: &Range<Anchor>,
) {
    // Bail out if the newest stored event is for a different buffer than
    // `latest_snapshot`.
    if let Some(last_event) = events.back() {
        if last_event.old_snapshot.remote_id() != latest_snapshot.remote_id() {
            return;
        }
    }

    // Count the mergeable run at the back of the queue. `next_old_event` holds the
    // newer event counted on the previous iteration; the newest event is always counted.
    let mut next_old_event = None;
    let mut mergeable_count = 0;
    for old_event in events.iter().rev() {
        if let Some(next_old_event) = &next_old_event
            && !old_event.can_merge(&next_old_event, latest_snapshot, latest_edit_range)
        {
            break;
        }
        mergeable_count += 1;
        next_old_event = Some(old_event);
    }

    // Zero or one event: nothing to merge.
    if mergeable_count <= 1 {
        return;
    }

    // `mergeable_count >= 2` here, so the range is non-empty and `peek` yields the
    // oldest event of the run.
    let mut events_to_merge = events.range(events.len() - mergeable_count..).peekable();
    let oldest_event = events_to_merge.peek().unwrap();
    let oldest_snapshot = oldest_event.old_snapshot.clone();

    if let Some((diff, edited_range)) =
        compute_diff_between_snapshots(&oldest_snapshot, end_snapshot)
    {
        let merged_event = match oldest_event.event.as_ref() {
            zeta_prompt::Event::BufferChange {
                old_path,
                path,
                in_open_source_repo,
                ..
            } => StoredEvent {
                // Path and open-source metadata come from the oldest event in the run.
                event: Arc::new(zeta_prompt::Event::BufferChange {
                    old_path: old_path.clone(),
                    path: path.clone(),
                    diff,
                    in_open_source_repo: *in_open_source_repo,
                    // The merged event is marked predicted only if every merged event was.
                    predicted: events_to_merge.all(|e| {
                        matches!(
                            e.event.as_ref(),
                            zeta_prompt::Event::BufferChange {
                                predicted: true,
                                ..
                            }
                        )
                    }),
                }),
                old_snapshot: oldest_snapshot.clone(),
                // Anchor the range returned by the diff computation in `end_snapshot`.
                edit_range: end_snapshot.anchor_before(edited_range.start)
                    ..end_snapshot.anchor_before(edited_range.end),
            },
        };
        // Replace the merged run with the single combined event.
        events.truncate(events.len() - mergeable_count);
        events.push_back(merged_event);
    }
}
2487
2488pub(crate) fn filter_redundant_excerpts(
2489 mut related_files: Vec<RelatedFile>,
2490 cursor_path: &Path,
2491 cursor_row_range: Range<u32>,
2492) -> Vec<RelatedFile> {
2493 for file in &mut related_files {
2494 if file.path.as_ref() == cursor_path {
2495 file.excerpts.retain(|excerpt| {
2496 excerpt.row_range.start < cursor_row_range.start
2497 || excerpt.row_range.end > cursor_row_range.end
2498 });
2499 }
2500 }
2501 related_files.retain(|file| !file.excerpts.is_empty());
2502 related_files
2503}
2504
2505#[derive(Error, Debug)]
2506#[error(
2507 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
2508)]
2509pub struct ZedUpdateRequiredError {
2510 minimum_version: Version,
2511}
2512
/// The user's recorded answer to the edit-prediction data-collection prompt.
#[derive(Clone, Copy, Debug)]
pub enum DataCollectionChoice {
    /// No choice has been recorded.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2519
2520impl DataCollectionChoice {
2521 pub fn is_enabled(self, cx: &App) -> bool {
2522 if cx.is_staff() {
2523 return true;
2524 }
2525 match self {
2526 Self::Enabled => true,
2527 Self::NotAnswered | Self::Disabled => false,
2528 }
2529 }
2530
2531 #[must_use]
2532 pub fn toggle(&self) -> DataCollectionChoice {
2533 match self {
2534 Self::Enabled => Self::Disabled,
2535 Self::Disabled => Self::Enabled,
2536 Self::NotAnswered => Self::Enabled,
2537 }
2538 }
2539}
2540
2541impl From<bool> for DataCollectionChoice {
2542 fn from(value: bool) -> Self {
2543 match value {
2544 true => DataCollectionChoice::Enabled,
2545 false => DataCollectionChoice::Disabled,
2546 }
2547 }
2548}
2549
2550struct ZedPredictUpsell;
2551
2552impl Dismissable for ZedPredictUpsell {
2553 const KEY: &'static str = "dismissed-edit-predict-upsell";
2554
2555 fn dismissed() -> bool {
2556 // To make this backwards compatible with older versions of Zed, we
2557 // check if the user has seen the previous Edit Prediction Onboarding
2558 // before, by checking the data collection choice which was written to
2559 // the database once the user clicked on "Accept and Enable"
2560 if KEY_VALUE_STORE
2561 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2562 .log_err()
2563 .is_some_and(|s| s.is_some())
2564 {
2565 return true;
2566 }
2567
2568 KEY_VALUE_STORE
2569 .read_kvp(Self::KEY)
2570 .log_err()
2571 .is_some_and(|s| s.is_some())
2572 }
2573}
2574
2575pub fn should_show_upsell_modal() -> bool {
2576 !ZedPredictUpsell::dismissed()
2577}
2578
2579pub fn init(cx: &mut App) {
2580 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
2581 workspace.register_action(
2582 move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
2583 ZedPredictModal::toggle(
2584 workspace,
2585 workspace.user_store().clone(),
2586 workspace.client().clone(),
2587 window,
2588 cx,
2589 )
2590 },
2591 );
2592
2593 workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
2594 update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
2595 settings
2596 .project
2597 .all_languages
2598 .edit_predictions
2599 .get_or_insert_default()
2600 .provider = Some(EditPredictionProvider::None)
2601 });
2602 });
2603 fn copilot_for_project(project: &Entity<Project>, cx: &mut App) -> Option<Entity<Copilot>> {
2604 EditPredictionStore::try_global(cx).and_then(|store| {
2605 store.update(cx, |this, cx| this.start_copilot_for_project(project, cx))
2606 })
2607 }
2608
2609 workspace.register_action(|workspace, _: &SignIn, window, cx| {
2610 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2611 copilot_ui::initiate_sign_in(copilot, window, cx);
2612 }
2613 });
2614 workspace.register_action(|workspace, _: &Reinstall, window, cx| {
2615 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2616 copilot_ui::reinstall_and_sign_in(copilot, window, cx);
2617 }
2618 });
2619 workspace.register_action(|workspace, _: &SignOut, window, cx| {
2620 if let Some(copilot) = copilot_for_project(workspace.project(), cx) {
2621 copilot_ui::initiate_sign_out(copilot, window, cx);
2622 }
2623 });
2624 })
2625 .detach();
2626}