1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
13use collections::{HashMap, HashSet};
14use command_palette_hooks::CommandPaletteFilter;
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{
17 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
18 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
19 SyntaxIndex, SyntaxIndexState,
20};
21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
22use futures::channel::mpsc::UnboundedReceiver;
23use futures::channel::{mpsc, oneshot};
24use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
25use gpui::BackgroundExecutor;
26use gpui::{
27 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
28 http_client::{self, AsyncBody, Method},
29 prelude::*,
30};
31use language::{
32 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
33};
34use language::{BufferSnapshot, OffsetRangeExt};
35use language_model::{LlmApiToken, RefreshLlmTokenListener};
36use open_ai::FunctionDefinition;
37use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
38use release_channel::AppVersion;
39use semver::Version;
40use serde::de::DeserializeOwned;
41use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
42use std::any::{Any as _, TypeId};
43use std::collections::{VecDeque, hash_map};
44use telemetry_events::EditPredictionRating;
45use workspace::Workspace;
46
47use std::ops::Range;
48use std::path::Path;
49use std::rc::Rc;
50use std::str::FromStr as _;
51use std::sync::{Arc, LazyLock};
52use std::time::{Duration, Instant};
53use std::{env, mem};
54use thiserror::Error;
55use util::rel_path::RelPathBuf;
56use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
57use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
58
59pub mod assemble_excerpts;
60mod license_detection;
61mod onboarding_modal;
62mod prediction;
63mod provider;
64mod rate_prediction_modal;
65pub mod retrieval_search;
66pub mod sweep_ai;
67pub mod udiff;
68mod xml_edits;
69pub mod zeta1;
70
71#[cfg(test)]
72mod zeta_tests;
73
74use crate::assemble_excerpts::assemble_excerpts;
75use crate::license_detection::LicenseDetectionWatcher;
76use crate::onboarding_modal::ZedPredictModal;
77pub use crate::prediction::EditPrediction;
78pub use crate::prediction::EditPredictionId;
79pub use crate::prediction::EditPredictionInputs;
80use crate::prediction::EditPredictionResult;
81use crate::rate_prediction_modal::{
82 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
83 ThumbsUpActivePrediction,
84};
85pub use crate::sweep_ai::SweepAi;
86use crate::zeta1::request_prediction_with_zeta1;
87pub use provider::ZetaEditPredictionProvider;
88
// Registers workspace actions in the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Maximum row distance between two consecutive edits for them to be
/// coalesced into a single buffer-change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key under which the user's data collection choice is persisted in the key-value store.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long `handle_rejected_predictions` waits to batch up rejections
/// before flushing them in a single request.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
106
107pub struct SweepFeatureFlag;
108
109impl FeatureFlag for SweepFeatureFlag {
110 const NAME: &str = "sweep-ai";
111}
/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place roughly half of the excerpt bytes before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: agentic retrieval with the default excerpt sizing.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for agentic context retrieval.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
124
/// Default options for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used when a [`Zeta`] instance is created.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
143
/// Whether to route requests to a local Ollama server
/// (set via a non-empty `ZED_ZETA2_OLLAMA` env var).
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for context retrieval; overridable via `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit predictions; overridable via `ZED_ZETA2_MODEL`
/// ("zeta2-exp" is an alias for the fine-tuned model).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional override URL for prediction requests; `ZED_PREDICT_EDITS_URL`
/// wins, otherwise the local Ollama endpoint when `USE_OLLAMA` is set.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
171
/// Feature flag gating the zeta2 edit prediction backend.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Always available to staff, regardless of the flag's rollout state.
    fn enabled_for_staff() -> bool {
        true
    }
}
181
/// Global wrapper holding the process-wide [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
186
/// Central state for edit predictions: per-project tracking, model selection,
/// auth tokens, and prediction rating/rejection bookkeeping.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the LLM-token refresh subscription alive for this entity's lifetime.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server says this client version is too old to predict edits.
    update_required: bool,
    // When present, debug events are forwarded to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections queued here are batched and flushed by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    // Most-recently-shown predictions, newest first (bounded; see
    // `did_show_current_prediction`).
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
205
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}

/// Tunable options controlling context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}
223
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval via search tool calls.
    Agentic(AgenticContextOptions),
    /// Retrieval backed by the local syntax index.
    Syntax(EditPredictionContextOptions),
}

/// Options for the agentic context mode.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
234
235impl ContextMode {
236 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
237 match self {
238 ContextMode::Agentic(options) => &options.excerpt,
239 ContextMode::Syntax(options) => &options.excerpt,
240 }
241 }
242}
243
/// Events emitted on the debug channel opened by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Debug payload for the start of a context retrieval pass.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Debug payload for context retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Debug payload describing a prediction request and a handle to its response.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error message) and request latency.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Debug payload listing the search queries generated for retrieval.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
284
/// Per-project prediction state tracked by [`Zeta`].
struct ZetaProject {
    // Only populated when the context mode is `ContextMode::Syntax`.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized buffer-change events, oldest first (bounded by `EVENT_COUNT_MAX`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // The most recent, not-yet-finalized change; may still be coalesced with
    // further edits (see `report_changes_for_buffer`).
    last_event: Option<LastEvent>,
    // Most recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Which entity last triggered a refresh, and when — used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled before completing.
    cancelled_predictions: HashSet<usize>,
    // Retrieved context ranges per buffer, if a retrieval pass has completed.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    // One license watcher per worktree, used to tag events as open-source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
303
impl ZetaProject {
    /// All tracked events, including the still-pending `last_event`
    /// (finalized on the fly) when it yields an actual change.
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks a pending prediction as cancelled, and once its task resolves to
    /// a prediction id, reports that prediction as rejected (never shown).
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }
}
337
/// The prediction currently offered to the user, plus provenance and display state.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Whether this prediction has been rendered to the user at least once.
    pub was_shown: bool,
}
344
impl CurrentEditPrediction {
    /// Whether `self` (a new prediction) should replace `old_prediction`.
    ///
    /// Returns false if the new prediction no longer interpolates against its
    /// buffer; returns true if the buffers differ or the old prediction is
    /// stale. Otherwise, for the single-edit same-buffer case, only replaces
    /// when the new edit extends the old one at the same range, to avoid UI
    /// thrash.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
383
384#[derive(Debug, Clone)]
385enum PredictionRequestedBy {
386 DiagnosticsUpdate,
387 Buffer(EntityId),
388}
389
390impl PredictionRequestedBy {
391 pub fn buffer_id(&self) -> Option<EntityId> {
392 match self {
393 PredictionRequestedBy::DiagnosticsUpdate => None,
394 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
395 }
396 }
397}
398
/// An in-flight prediction request; the task resolves to the resulting
/// prediction's id, or `None` if no prediction was produced.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}

/// Tracking state for a buffer whose edits feed the event history.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    // Edit subscription + release observer; dropped when the buffer is removed.
    _subscriptions: [gpui::Subscription; 2],
}

/// A buffer change that has not yet been finalized into an event, so that
/// nearby follow-up edits can still be coalesced into it.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // Anchor at the end of the most recent edit, used for proximity grouping.
    end_edit_anchor: Option<Anchor>,
}
434
impl LastEvent {
    /// Converts this pending change into a `BufferChange` event, or `None`
    /// when the path is unchanged and the diff is empty (i.e. a no-op).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Only mark as open source when both snapshots belong to worktrees
        // whose license watcher says the project is open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
471
472fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
473 if let Some(file) = snapshot.file() {
474 file.full_path(cx).into()
475 } else {
476 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
477 }
478}
479
480impl Zeta {
481 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
482 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
483 }
484
485 pub fn global(
486 client: &Arc<Client>,
487 user_store: &Entity<UserStore>,
488 cx: &mut App,
489 ) -> Entity<Self> {
490 cx.try_global::<ZetaGlobal>()
491 .map(|global| global.0.clone())
492 .unwrap_or_else(|| {
493 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
494 cx.set_global(ZetaGlobal(zeta.clone()));
495 zeta
496 })
497 }
498
    /// Creates a new [`Zeta`]: loads the persisted data-collection choice,
    /// spawns the background task that batches rejection reports, and
    /// subscribes to LLM-token refresh events.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections sent on `reject_tx` are batched and flushed by
        // `handle_rejected_predictions` for the lifetime of the process.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token,
            // Refresh the LLM token whenever the listener fires.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
554
    /// Selects which backend serves subsequent edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
558
    /// Whether a Sweep API token is already available without waiting.
    // NOTE(review): `api_token` appears to be a shareable future; `now_or_never`
    // only succeeds if it has already resolved — confirm against `SweepAi`.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai
            .api_token
            .clone()
            .now_or_never()
            .flatten()
            .is_some()
    }
567
    /// Installs a cache used to memoize requests during evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
572
    /// Opens a debug channel; subsequent debug events are sent to the
    /// returned receiver. Replaces any previously opened channel.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }
578
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
586
587 pub fn clear_history(&mut self) {
588 for zeta_project in self.projects.values_mut() {
589 zeta_project.events.clear();
590 }
591 }
592
    /// Iterates over the retrieved context (buffer, anchor ranges) for the
    /// given project; empty if the project is unknown or has no context yet.
    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }
611
612 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
613 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
614 self.user_store.read(cx).edit_prediction_usage()
615 } else {
616 None
617 }
618 }
619
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project` so its edits feed the event
    /// history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
633
    /// Returns the project's state, creating it (and, in syntax context mode,
    /// its syntax index) on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // The syntax index is only needed for syntax-based context.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
666
    /// Reacts to project events: maintains the recent-paths MRU list on active
    /// entry changes, and (behind the zeta2 flag) refreshes predictions on
    /// diagnostics updates.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front of the MRU list, deduplicating.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
699
    /// Registers `buffer` in `zeta_project`: sets up a license watcher for its
    /// worktree (if any) and subscribes to buffer edits/release. Returns the
    /// existing registration when the buffer is already tracked.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license watcher per worktree; remove it when the
        // worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record edits into the event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
759
    /// Records a buffer change into the project's event history, coalescing it
    /// into the pending `last_event` when it is a nearby follow-up edit to the
    /// same buffer; otherwise finalizes the pending event and starts a new one.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // No-op if nothing actually changed since the recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the last edit in this change, for proximity checks.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The change continues the pending event only if it picks up
            // exactly where that event's snapshot left off...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and both edits end within CHANGE_GROUPING_LINE_SPAN rows.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Keep room for the pending event so the total stays within EVENT_COUNT_MAX.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
822
    /// The current prediction as seen from `buffer`: `Local` when the
    /// prediction targets this buffer, `Jump` when it targets another buffer
    /// but should be offered from this one (same requesting buffer, or a
    /// diagnostics-triggered prediction), otherwise `None`.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
854
    /// Accepts the current prediction: cancels any pending requests and
    /// reports the acceptance to the server (no-op for the Sweep backend).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any still-pending prediction requests.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the accept endpoint (for testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
907
    /// Long-running task that batches incoming rejection reports and flushes
    /// them to the server, debounced by `REJECT_REQUEST_DEBOUNCE`. Failed
    /// flushes keep their items in the batch for retry on the next flush.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Keep accumulating until the batch is half-full, another item is
            // immediately available, or the debounce timer fires.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush at most MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST items,
            // taking the most recent tail of the batch.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Only drop the flushed items on success; otherwise retry later.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
965
966 fn reject_current_prediction(
967 &mut self,
968 reason: EditPredictionRejectReason,
969 project: &Entity<Project>,
970 ) {
971 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
972 project_state.pending_predictions.clear();
973 if let Some(prediction) = project_state.current_prediction.take() {
974 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
975 }
976 };
977 }
978
979 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
980 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
981 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
982 if !current_prediction.was_shown {
983 current_prediction.was_shown = true;
984 self.shown_predictions
985 .push_front(current_prediction.prediction.clone());
986 if self.shown_predictions.len() > 50 {
987 let completion = self.shown_predictions.pop_back().unwrap();
988 self.rated_predictions.remove(&completion.id);
989 }
990 }
991 }
992 }
993 }
994
    /// Queues a rejection report for batching by `handle_rejected_predictions`.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
    ) {
        self.reject_predictions_tx
            .unbounded_send(EditPredictionRejection {
                request_id: prediction_id.to_string(),
                reason,
                was_shown,
            })
            .log_err();
    }
1009
1010 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1011 self.projects
1012 .get(&project.entity_id())
1013 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1014 }
1015
    /// Queues a (throttled) prediction refresh triggered by activity in
    /// `buffer` at `position`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Tag the result with the originating buffer for later display logic.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1051
    /// Queues a prediction refresh in response to a diagnostics update,
    /// predicting at the next diagnostic location in the active buffer.
    /// Skipped when a buffer-originated prediction is already present.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Open the project's active buffer, if there is one.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find the next diagnostic to jump to; bail if there is none.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-originated prediction may have arrived while
                        // this request was in flight; if so, reject this one in
                        // its favor.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1132
    /// Minimum interval between consecutive prediction refreshes triggered by
    /// the same entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero in tests so they don't have to wait on real timers.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1137
1138 fn queue_prediction_refresh(
1139 &mut self,
1140 project: Entity<Project>,
1141 throttle_entity: EntityId,
1142 cx: &mut Context<Self>,
1143 do_refresh: impl FnOnce(
1144 WeakEntity<Self>,
1145 &mut AsyncApp,
1146 )
1147 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1148 + 'static,
1149 ) {
1150 let zeta_project = self.get_or_init_zeta_project(&project, cx);
1151 let pending_prediction_id = zeta_project.next_pending_prediction_id;
1152 zeta_project.next_pending_prediction_id += 1;
1153 let last_request = zeta_project.last_prediction_refresh;
1154
1155 let task = cx.spawn(async move |this, cx| {
1156 if let Some((last_entity, last_timestamp)) = last_request
1157 && throttle_entity == last_entity
1158 && let Some(timeout) =
1159 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1160 {
1161 cx.background_executor().timer(timeout).await;
1162 }
1163
1164 // If this task was cancelled before the throttle timeout expired,
1165 // do not perform a request.
1166 let mut is_cancelled = true;
1167 this.update(cx, |this, cx| {
1168 let project_state = this.get_or_init_zeta_project(&project, cx);
1169 if !project_state
1170 .cancelled_predictions
1171 .remove(&pending_prediction_id)
1172 {
1173 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1174 is_cancelled = false;
1175 }
1176 })
1177 .ok();
1178 if is_cancelled {
1179 return None;
1180 }
1181
1182 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1183 let new_prediction_id = new_prediction_result
1184 .as_ref()
1185 .map(|(prediction, _)| prediction.id.clone());
1186
1187 // When a prediction completes, remove it from the pending list, and cancel
1188 // any pending predictions that were enqueued before it.
1189 this.update(cx, |this, cx| {
1190 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1191
1192 let is_cancelled = zeta_project
1193 .cancelled_predictions
1194 .remove(&pending_prediction_id);
1195
1196 let new_current_prediction = if !is_cancelled
1197 && let Some((prediction_result, requested_by)) = new_prediction_result
1198 {
1199 match prediction_result.prediction {
1200 Ok(prediction) => {
1201 let new_prediction = CurrentEditPrediction {
1202 requested_by,
1203 prediction,
1204 was_shown: false,
1205 };
1206
1207 if let Some(current_prediction) =
1208 zeta_project.current_prediction.as_ref()
1209 {
1210 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1211 {
1212 this.reject_current_prediction(
1213 EditPredictionRejectReason::Replaced,
1214 &project,
1215 );
1216
1217 Some(new_prediction)
1218 } else {
1219 this.reject_prediction(
1220 new_prediction.prediction.id,
1221 EditPredictionRejectReason::CurrentPreferred,
1222 false,
1223 );
1224 None
1225 }
1226 } else {
1227 Some(new_prediction)
1228 }
1229 }
1230 Err(reject_reason) => {
1231 this.reject_prediction(prediction_result.id, reject_reason, false);
1232 None
1233 }
1234 }
1235 } else {
1236 None
1237 };
1238
1239 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1240
1241 if let Some(new_prediction) = new_current_prediction {
1242 zeta_project.current_prediction = Some(new_prediction);
1243 }
1244
1245 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
1246 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1247 if pending_prediction.id == pending_prediction_id {
1248 pending_predictions.remove(ix);
1249 for pending_prediction in pending_predictions.drain(0..ix) {
1250 zeta_project.cancel_pending_prediction(pending_prediction, cx)
1251 }
1252 break;
1253 }
1254 }
1255 this.get_or_init_zeta_project(&project, cx)
1256 .pending_predictions = pending_predictions;
1257 cx.notify();
1258 })
1259 .ok();
1260
1261 new_prediction_id
1262 });
1263
1264 if zeta_project.pending_predictions.len() <= 1 {
1265 zeta_project.pending_predictions.push(PendingPrediction {
1266 id: pending_prediction_id,
1267 task,
1268 });
1269 } else if zeta_project.pending_predictions.len() == 2 {
1270 let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
1271 zeta_project.pending_predictions.push(PendingPrediction {
1272 id: pending_prediction_id,
1273 task,
1274 });
1275 zeta_project.cancel_pending_prediction(pending_prediction, cx);
1276 }
1277 }
1278
1279 pub fn request_prediction(
1280 &mut self,
1281 project: &Entity<Project>,
1282 active_buffer: &Entity<Buffer>,
1283 position: language::Anchor,
1284 trigger: PredictEditsRequestTrigger,
1285 cx: &mut Context<Self>,
1286 ) -> Task<Result<Option<EditPredictionResult>>> {
1287 self.request_prediction_internal(
1288 project.clone(),
1289 active_buffer.clone(),
1290 position,
1291 trigger,
1292 cx.has_flag::<Zeta2FeatureFlag>(),
1293 cx,
1294 )
1295 }
1296
    /// Dispatches a prediction request to the configured model backend
    /// (Zeta1, Zeta2, or Sweep).
    ///
    /// When `allow_jump` is set and the backend returns no prediction for the
    /// cursor location, retries once — with jumping disabled to prevent
    /// recursion — at the nearest diagnostic outside the already-searched
    /// range, but only if there are recent edit events to base it on.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above/below the cursor considered "nearby" for diagnostics.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        // Unwrap is safe: the project entry was just initialized above.
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // No prediction at the cursor: optionally fall back to predicting
            // at the nearest diagnostic outside the range searched above.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                // Disallow jumping again to avoid recursing.
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1388
    /// Finds the next diagnostic to jump to for a prediction.
    ///
    /// First looks in the active buffer for the diagnostic closest to the
    /// cursor that falls *outside* `active_buffer_diagnostic_search_range`
    /// (diagnostics inside that range were already covered by the previous
    /// request). If none is found, falls back to opening the project buffer
    /// with diagnostics whose path shares the most leading components with
    /// the active buffer's path, and picks its most severe diagnostic.
    ///
    /// Returns `None` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    // Already covered by the previous request; skip.
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // No candidate in the active buffer; search other buffers with
            // diagnostics, preferring paths near the active buffer's path.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Pick the most severe diagnostic in the chosen buffer
                // (lower severity value = more severe).
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1467
    /// Requests an edit prediction from the Zeta2 model.
    ///
    /// Gathers context (agentic per-file excerpts or syntax-index based,
    /// depending on `options.context`), builds the prompt, sends it to the
    /// LLM endpoint on a background task, parses the returned diff into
    /// concrete buffer edits, and wraps them in an [`EditPredictionResult`].
    ///
    /// Returns `Ok(None)` when no excerpt/context could be gathered for the
    /// cursor position.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Snapshot of the syntax index, if one is running for this project.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let file = active_buffer.read(cx).file();
        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        // For evals, wait until context buffers finish parsing so results
        // are deterministic.
        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // (buffer, snapshot, full path, anchor ranges) for each context file.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Stable ordering for prompt reproducibility.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                // Hold the index lock for the duration of context gathering.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                // Build the cloud request according to the configured context mode.
                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            // No excerpt selectable at the cursor; no prediction.
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's entry
                        // (moving it last), or append a new entry for it.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line-based excerpts for the wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            // In agentic mode the excerpt travels in `included_files`.
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                            trigger,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                            trigger,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                // Retained alongside the prediction for telemetry/rating.
                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // Notify the debug view that a request is starting, giving it a
                // channel to receive the response on.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-only escape hatch to inspect prompts without network calls.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                // A response with no text is an explicit "no edits" prediction.
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((Some((request_id, None)), usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path in the model's diff back to a context buffer.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output into (snapshot, edits) according to the
                // prompt format it was asked to produce.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            // Sentinel emitted by these formats for "no edits".
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        Some((
                            inputs,
                            edited_buffer,
                            edited_buffer_snapshot.clone(),
                            edits,
                            received_response_at,
                        )),
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, prediction)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // Empty prediction: record the id so it can be reported as rejected.
                let Some((
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = prediction
                else {
                    return Ok(Some(EditPredictionResult {
                        id,
                        prediction: Err(EditPredictionRejectReason::Empty),
                    }));
                };

                // TODO telemetry: duration, etc
                Ok(Some(
                    EditPredictionResult::new(
                        id,
                        &edited_buffer,
                        &edited_buffer_snapshot,
                        edits.into(),
                        buffer_snapshotted_at,
                        received_response_at,
                        inputs,
                        cx,
                    )
                    .await,
                ))
            }
        })
    }
1886
    /// Sends a pre-built OpenAI-style request to the edit prediction endpoint
    /// (or the `PREDICT_EDITS_URL` override) and returns the parsed response
    /// plus any usage info the server reported.
    ///
    /// With the `eval-support` feature, responses are cached keyed on a hash
    /// of the URL and serialized request so eval runs avoid network calls.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                // Cache hit: no network request, and no usage info to report.
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            // Cache miss: remember the key so the response is stored below.
            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1944
    /// Unwraps an API response pair, recording edit-prediction usage in the
    /// user store on success.
    ///
    /// On a [`ZedUpdateRequiredError`], marks this entity as requiring an
    /// update and shows a notification linking to the releases page; the
    /// error is still propagated to the caller.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1988
    /// Performs an authenticated POST to an LLM endpoint and deserializes the
    /// JSON response.
    ///
    /// `build` finishes constructing the request from a builder that already
    /// carries the content-type, bearer token, and Zed version headers. If
    /// the server reports an expired LLM token, the token is refreshed and
    /// the request retried exactly once. A minimum-version header from the
    /// server turns into a [`ZedUpdateRequiredError`] when this client is
    /// too old.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the initial attempt plus one token-refresh retry.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server-advertised minimum client version, if present.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh and retry once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2052
    /// Gap between refresh requests after which the user counts as having
    /// been idle, making the next edit refresh context immediately.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Delay after the user pauses editing before context is refreshed.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2055
    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to Zeta2 in agentic context mode.
        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
            return;
        }

        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // "Idle" = no refresh request within the idle duration (or ever).
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task drops (cancels) it, so only
        // the latest edit's task survives the debounce window.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // First edit after idling: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Keep the task alive so the refresh runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2107
2108 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2109 // and avoid spawning more than one concurrent task.
2110 pub fn refresh_context(
2111 &mut self,
2112 project: Entity<Project>,
2113 buffer: Entity<language::Buffer>,
2114 cursor_position: language::Anchor,
2115 cx: &mut Context<Self>,
2116 ) -> Task<Result<()>> {
2117 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2118 return Task::ready(anyhow::Ok(()));
2119 };
2120
2121 let ContextMode::Agentic(options) = &self.options().context else {
2122 return Task::ready(anyhow::Ok(()));
2123 };
2124
2125 let snapshot = buffer.read(cx).snapshot();
2126 let cursor_point = cursor_position.to_point(&snapshot);
2127 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
2128 cursor_point,
2129 &snapshot,
2130 &options.excerpt,
2131 None,
2132 ) else {
2133 return Task::ready(Ok(()));
2134 };
2135
2136 let app_version = AppVersion::global(cx);
2137 let client = self.client.clone();
2138 let llm_token = self.llm_token.clone();
2139 let debug_tx = self.debug_tx.clone();
2140 let current_file_path: Arc<Path> = snapshot
2141 .file()
2142 .map(|f| f.full_path(cx).into())
2143 .unwrap_or_else(|| Path::new("untitled").into());
2144
2145 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
2146 predict_edits_v3::PlanContextRetrievalRequest {
2147 excerpt: cursor_excerpt.text(&snapshot).body,
2148 excerpt_path: current_file_path,
2149 excerpt_line_range: cursor_excerpt.line_range,
2150 cursor_file_max_row: Line(snapshot.max_point().row),
2151 events: zeta_project.events(cx),
2152 },
2153 ) {
2154 Ok(prompt) => prompt,
2155 Err(err) => {
2156 return Task::ready(Err(err));
2157 }
2158 };
2159
2160 if let Some(debug_tx) = &debug_tx {
2161 debug_tx
2162 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
2163 ZetaContextRetrievalStartedDebugInfo {
2164 project: project.clone(),
2165 timestamp: Instant::now(),
2166 search_prompt: prompt.clone(),
2167 },
2168 ))
2169 .ok();
2170 }
2171
2172 pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
2173 let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
2174 language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
2175 );
2176
2177 let description = schema
2178 .get("description")
2179 .and_then(|description| description.as_str())
2180 .unwrap()
2181 .to_string();
2182
2183 (schema.into(), description)
2184 });
2185
2186 let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
2187
2188 let request = open_ai::Request {
2189 model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
2190 messages: vec![open_ai::RequestMessage::User {
2191 content: open_ai::MessageContent::Plain(prompt),
2192 }],
2193 stream: false,
2194 max_completion_tokens: None,
2195 stop: Default::default(),
2196 temperature: 0.7,
2197 tool_choice: None,
2198 parallel_tool_calls: None,
2199 tools: vec![open_ai::ToolDefinition::Function {
2200 function: FunctionDefinition {
2201 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
2202 description: Some(tool_description),
2203 parameters: Some(tool_schema),
2204 },
2205 }],
2206 prompt_cache_key: None,
2207 reasoning_effort: None,
2208 };
2209
2210 #[cfg(feature = "eval-support")]
2211 let eval_cache = self.eval_cache.clone();
2212
2213 cx.spawn(async move |this, cx| {
2214 log::trace!("Sending search planning request");
2215 let response = Self::send_raw_llm_request(
2216 request,
2217 client,
2218 llm_token,
2219 app_version,
2220 #[cfg(feature = "eval-support")]
2221 eval_cache.clone(),
2222 #[cfg(feature = "eval-support")]
2223 EvalCacheEntryKind::Context,
2224 )
2225 .await;
2226 let mut response = Self::handle_api_response(&this, response, cx)?;
2227 log::trace!("Got search planning response");
2228
2229 let choice = response
2230 .choices
2231 .pop()
2232 .context("No choices in retrieval response")?;
2233 let open_ai::RequestMessage::Assistant {
2234 content: _,
2235 tool_calls,
2236 } = choice.message
2237 else {
2238 anyhow::bail!("Retrieval response didn't include an assistant message");
2239 };
2240
2241 let mut queries: Vec<SearchToolQuery> = Vec::new();
2242 for tool_call in tool_calls {
2243 let open_ai::ToolCallContent::Function { function } = tool_call.content;
2244 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
2245 log::warn!(
2246 "Context retrieval response tried to call an unknown tool: {}",
2247 function.name
2248 );
2249
2250 continue;
2251 }
2252
2253 let input: SearchToolInput = serde_json::from_str(&function.arguments)
2254 .with_context(|| format!("invalid search json {}", &function.arguments))?;
2255 queries.extend(input.queries);
2256 }
2257
2258 if let Some(debug_tx) = &debug_tx {
2259 debug_tx
2260 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
2261 ZetaSearchQueryDebugInfo {
2262 project: project.clone(),
2263 timestamp: Instant::now(),
2264 search_queries: queries.clone(),
2265 },
2266 ))
2267 .ok();
2268 }
2269
2270 log::trace!("Running retrieval search: {queries:#?}");
2271
2272 let related_excerpts_result = retrieval_search::run_retrieval_searches(
2273 queries,
2274 project.clone(),
2275 #[cfg(feature = "eval-support")]
2276 eval_cache,
2277 cx,
2278 )
2279 .await;
2280
2281 log::trace!("Search queries executed");
2282
2283 if let Some(debug_tx) = &debug_tx {
2284 debug_tx
2285 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
2286 ZetaContextRetrievalDebugInfo {
2287 project: project.clone(),
2288 timestamp: Instant::now(),
2289 },
2290 ))
2291 .ok();
2292 }
2293
2294 this.update(cx, |this, _cx| {
2295 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
2296 return Ok(());
2297 };
2298 zeta_project.refresh_context_task.take();
2299 if let Some(debug_tx) = &this.debug_tx {
2300 debug_tx
2301 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
2302 ZetaContextRetrievalDebugInfo {
2303 project,
2304 timestamp: Instant::now(),
2305 },
2306 ))
2307 .ok();
2308 }
2309 match related_excerpts_result {
2310 Ok(excerpts) => {
2311 zeta_project.context = Some(excerpts);
2312 Ok(())
2313 }
2314 Err(error) => Err(error),
2315 }
2316 })?
2317 })
2318 }
2319
2320 pub fn set_context(
2321 &mut self,
2322 project: Entity<Project>,
2323 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2324 ) {
2325 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2326 zeta_project.context = Some(context);
2327 }
2328 }
2329
2330 fn gather_nearby_diagnostics(
2331 cursor_offset: usize,
2332 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2333 snapshot: &BufferSnapshot,
2334 max_diagnostics_bytes: usize,
2335 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2336 // TODO: Could make this more efficient
2337 let mut diagnostic_groups = Vec::new();
2338 for (language_server_id, diagnostics) in diagnostic_sets {
2339 let mut groups = Vec::new();
2340 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2341 diagnostic_groups.extend(
2342 groups
2343 .into_iter()
2344 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2345 );
2346 }
2347
2348 // sort by proximity to cursor
2349 diagnostic_groups.sort_by_key(|group| {
2350 let range = &group.entries[group.primary_ix].range;
2351 if range.start >= cursor_offset {
2352 range.start - cursor_offset
2353 } else if cursor_offset >= range.end {
2354 cursor_offset - range.end
2355 } else {
2356 (cursor_offset - range.start).min(range.end - cursor_offset)
2357 }
2358 });
2359
2360 let mut results = Vec::new();
2361 let mut diagnostic_groups_truncated = false;
2362 let mut diagnostics_byte_count = 0;
2363 for group in diagnostic_groups {
2364 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2365 diagnostics_byte_count += raw_value.get().len();
2366 if diagnostics_byte_count > max_diagnostics_bytes {
2367 diagnostic_groups_truncated = true;
2368 break;
2369 }
2370 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2371 }
2372
2373 (results, diagnostic_groups_truncated)
2374 }
2375
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds the cloud `PredictEditsRequest` that zeta-cli would issue for
    /// `position` in `buffer`, without actually sending it.
    ///
    /// Context gathering runs on a background task, holding a lock on the
    /// project's syntax index state (when one exists). Fails when the buffer
    /// has no file path or excerpt selection fails.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Cloned handle to the syntax index's shared state, if this project
        // has a syntax index.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        // Worktree snapshots are captured up front so the background task can
        // resolve declaration paths without re-entering the app.
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory; `None` when the
        // path has no parent to pop.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                    PredictEditsRequestTrigger::Other,
                )
            })
        })
    }
2454
2455 pub fn wait_for_initial_indexing(
2456 &mut self,
2457 project: &Entity<Project>,
2458 cx: &mut Context<Self>,
2459 ) -> Task<Result<()>> {
2460 let zeta_project = self.get_or_init_zeta_project(project, cx);
2461 if let Some(syntax_index) = &zeta_project.syntax_index {
2462 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2463 } else {
2464 Task::ready(Ok(()))
2465 }
2466 }
2467
2468 fn is_file_open_source(
2469 &self,
2470 project: &Entity<Project>,
2471 file: &Arc<dyn File>,
2472 cx: &App,
2473 ) -> bool {
2474 if !file.is_local() || file.is_private() {
2475 return false;
2476 }
2477 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2478 return false;
2479 };
2480 zeta_project
2481 .license_detection_watchers
2482 .get(&file.worktree_id(cx))
2483 .as_ref()
2484 .is_some_and(|watcher| watcher.is_project_open_source())
2485 }
2486
2487 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2488 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2489 }
2490
2491 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2492 if !self.data_collection_choice.is_enabled() {
2493 return false;
2494 }
2495 events.iter().all(|event| {
2496 matches!(
2497 event.as_ref(),
2498 Event::BufferChange {
2499 in_open_source_repo: true,
2500 ..
2501 }
2502 )
2503 })
2504 }
2505
2506 fn load_data_collection_choice() -> DataCollectionChoice {
2507 let choice = KEY_VALUE_STORE
2508 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2509 .log_err()
2510 .flatten();
2511
2512 match choice.as_deref() {
2513 Some("true") => DataCollectionChoice::Enabled,
2514 Some("false") => DataCollectionChoice::Disabled,
2515 Some(_) => {
2516 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2517 DataCollectionChoice::NotAnswered
2518 }
2519 None => DataCollectionChoice::NotAnswered,
2520 }
2521 }
2522
    /// Iterator over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2534
    /// Records a user rating and free-form feedback for `prediction`, reports
    /// it to telemetry, and flushes telemetry events right away.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        // Remember the id so the UI can mark this prediction as rated
        // (see `is_prediction_rated`).
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush immediately rather than waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2553}
2554
2555pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2556 let choice = res.choices.pop()?;
2557 let output_text = match choice.message {
2558 open_ai::RequestMessage::Assistant {
2559 content: Some(open_ai::MessageContent::Plain(content)),
2560 ..
2561 } => content,
2562 open_ai::RequestMessage::Assistant {
2563 content: Some(open_ai::MessageContent::Multipart(mut content)),
2564 ..
2565 } => {
2566 if content.is_empty() {
2567 log::error!("No output from Baseten completion response");
2568 return None;
2569 }
2570
2571 match content.remove(0) {
2572 open_ai::MessagePart::Text { text } => text,
2573 open_ai::MessagePart::Image { .. } => {
2574 log::error!("Expected text, got an image");
2575 return None;
2576 }
2577 }
2578 }
2579 _ => {
2580 log::error!("Invalid response message: {:?}", choice.message);
2581 return None;
2582 }
2583 };
2584 Some(output_text)
2585}
2586
/// Error surfaced when the server requires a newer Zed version for edit
/// predictions (presumably derived from the
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header — confirm at the
/// construction site).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2594
2595fn make_syntax_context_cloud_request(
2596 excerpt_path: Arc<Path>,
2597 context: EditPredictionContext,
2598 events: Vec<Arc<predict_edits_v3::Event>>,
2599 can_collect_data: bool,
2600 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2601 diagnostic_groups_truncated: bool,
2602 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2603 debug_info: bool,
2604 worktrees: &Vec<worktree::Snapshot>,
2605 index_state: Option<&SyntaxIndexState>,
2606 prompt_max_bytes: Option<usize>,
2607 prompt_format: PromptFormat,
2608 trigger: PredictEditsRequestTrigger,
2609) -> predict_edits_v3::PredictEditsRequest {
2610 let mut signatures = Vec::new();
2611 let mut declaration_to_signature_index = HashMap::default();
2612 let mut referenced_declarations = Vec::new();
2613
2614 for snippet in context.declarations {
2615 let project_entry_id = snippet.declaration.project_entry_id();
2616 let Some(path) = worktrees.iter().find_map(|worktree| {
2617 worktree.entry_for_id(project_entry_id).map(|entry| {
2618 let mut full_path = RelPathBuf::new();
2619 full_path.push(worktree.root_name());
2620 full_path.push(&entry.path);
2621 full_path
2622 })
2623 }) else {
2624 continue;
2625 };
2626
2627 let parent_index = index_state.and_then(|index_state| {
2628 snippet.declaration.parent().and_then(|parent| {
2629 add_signature(
2630 parent,
2631 &mut declaration_to_signature_index,
2632 &mut signatures,
2633 index_state,
2634 )
2635 })
2636 });
2637
2638 let (text, text_is_truncated) = snippet.declaration.item_text();
2639 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2640 path: path.as_std_path().into(),
2641 text: text.into(),
2642 range: snippet.declaration.item_line_range(),
2643 text_is_truncated,
2644 signature_range: snippet.declaration.signature_range_in_item_text(),
2645 parent_index,
2646 signature_score: snippet.score(DeclarationStyle::Signature),
2647 declaration_score: snippet.score(DeclarationStyle::Declaration),
2648 score_components: snippet.components,
2649 });
2650 }
2651
2652 let excerpt_parent = index_state.and_then(|index_state| {
2653 context
2654 .excerpt
2655 .parent_declarations
2656 .last()
2657 .and_then(|(parent, _)| {
2658 add_signature(
2659 *parent,
2660 &mut declaration_to_signature_index,
2661 &mut signatures,
2662 index_state,
2663 )
2664 })
2665 });
2666
2667 predict_edits_v3::PredictEditsRequest {
2668 excerpt_path,
2669 excerpt: context.excerpt_text.body,
2670 excerpt_line_range: context.excerpt.line_range,
2671 excerpt_range: context.excerpt.range,
2672 cursor_point: predict_edits_v3::Point {
2673 line: predict_edits_v3::Line(context.cursor_point.row),
2674 column: context.cursor_point.column,
2675 },
2676 referenced_declarations,
2677 included_files: vec![],
2678 signatures,
2679 excerpt_parent,
2680 events,
2681 can_collect_data,
2682 diagnostic_groups,
2683 diagnostic_groups_truncated,
2684 git_info,
2685 debug_info,
2686 prompt_max_bytes,
2687 prompt_format,
2688 trigger,
2689 }
2690}
2691
2692fn add_signature(
2693 declaration_id: DeclarationId,
2694 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2695 signatures: &mut Vec<Signature>,
2696 index: &SyntaxIndexState,
2697) -> Option<usize> {
2698 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2699 return Some(*signature_index);
2700 }
2701 let Some(parent_declaration) = index.declaration(declaration_id) else {
2702 log::error!("bug: missing parent declaration");
2703 return None;
2704 };
2705 let parent_index = parent_declaration.parent().and_then(|parent| {
2706 add_signature(parent, declaration_to_signature_index, signatures, index)
2707 });
2708 let (text, text_is_truncated) = parent_declaration.signature_text();
2709 let signature_index = signatures.len();
2710 signatures.push(Signature {
2711 text: text.into(),
2712 text_is_truncated,
2713 parent_index,
2714 range: parent_declaration.signature_line_range(),
2715 });
2716 declaration_to_signature_index.insert(declaration_id, signature_index);
2717 Some(signature_index)
2718}
2719
#[cfg(feature = "eval-support")]
/// Cache key for recorded LLM exchanges: the request kind paired with a hash
/// of the request input.
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2722
#[cfg(feature = "eval-support")]
/// The kind of cached LLM exchange, used as the first half of `EvalCacheKey`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    // Context-retrieval planning requests.
    Context,
    // Retrieval search requests.
    Search,
    // Edit-prediction requests.
    Prediction,
}
2730
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    // Renders the lowercase kind name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{name}")
    }
}
2741
#[cfg(feature = "eval-support")]
/// Cache of raw LLM responses keyed by request kind and input hash, used by
/// eval runs (see the `eval_cache` reads/writes around `send_raw_llm_request`).
pub trait EvalCache: Send + Sync {
    /// Returns the previously recorded response for `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Records `value` as the response for `key`; `input` is presumably the
    /// raw request text kept for inspection — confirm against implementors.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2747
/// The user's persisted choice about edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // The user has not made a choice yet; treated as disabled.
    NotAnswered,
    Enabled,
    Disabled,
}
2754
2755impl DataCollectionChoice {
2756 pub fn is_enabled(self) -> bool {
2757 match self {
2758 Self::Enabled => true,
2759 Self::NotAnswered | Self::Disabled => false,
2760 }
2761 }
2762
2763 pub fn is_answered(self) -> bool {
2764 match self {
2765 Self::Enabled | Self::Disabled => true,
2766 Self::NotAnswered => false,
2767 }
2768 }
2769
2770 #[must_use]
2771 pub fn toggle(&self) -> DataCollectionChoice {
2772 match self {
2773 Self::Enabled => Self::Disabled,
2774 Self::Disabled => Self::Enabled,
2775 Self::NotAnswered => Self::Enabled,
2776 }
2777 }
2778}
2779
2780impl From<bool> for DataCollectionChoice {
2781 fn from(value: bool) -> Self {
2782 match value {
2783 true => DataCollectionChoice::Enabled,
2784 false => DataCollectionChoice::Disabled,
2785 }
2786 }
2787}
2788
2789struct ZedPredictUpsell;
2790
2791impl Dismissable for ZedPredictUpsell {
2792 const KEY: &'static str = "dismissed-edit-predict-upsell";
2793
2794 fn dismissed() -> bool {
2795 // To make this backwards compatible with older versions of Zed, we
2796 // check if the user has seen the previous Edit Prediction Onboarding
2797 // before, by checking the data collection choice which was written to
2798 // the database once the user clicked on "Accept and Enable"
2799 if KEY_VALUE_STORE
2800 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2801 .log_err()
2802 .is_some_and(|s| s.is_some())
2803 {
2804 return true;
2805 }
2806
2807 KEY_VALUE_STORE
2808 .read_kvp(Self::KEY)
2809 .log_err()
2810 .is_some_and(|s| s.is_some())
2811 }
2812}
2813
/// Whether the edit-prediction upsell modal should be shown, i.e. the user has
/// not dismissed it (in this or an older version of Zed).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2817
/// Registers edit-prediction workspace actions and their command-palette
/// feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI only responds when the feature flag is enabled.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // "Reset" writes `EditPredictionProvider::None` back to settings,
        // clearing the configured provider.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2853
/// Hides zeta command-palette actions by default and keeps their visibility in
/// sync with the rate-completions feature flag and the disable-AI setting.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Every zeta-related action, hidden as a group when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden up front; the observers below re-show rate-completions when the
    // flag allows it.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // React to settings changes (notably the disable-AI setting).
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // React to feature-flag changes, unless AI is disabled entirely.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2905
2906#[cfg(test)]
2907mod tests {
2908 use std::{path::Path, sync::Arc, time::Duration};
2909
2910 use client::UserStore;
2911 use clock::FakeSystemClock;
2912 use cloud_llm_client::{
2913 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2914 };
2915 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2916 use futures::{
2917 AsyncReadExt, StreamExt,
2918 channel::{mpsc, oneshot},
2919 };
2920 use gpui::{
2921 Entity, TestAppContext,
2922 http_client::{FakeHttpClient, Response},
2923 prelude::*,
2924 };
2925 use indoc::indoc;
2926 use language::OffsetRangeExt as _;
2927 use open_ai::Usage;
2928 use pretty_assertions::{assert_eq, assert_matches};
2929 use project::{FakeFs, Project};
2930 use serde_json::json;
2931 use settings::SettingsStore;
2932 use util::path;
2933 use uuid::Uuid;
2934
2935 use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
2936
    // Exercises Zeta's current-prediction state machine end to end against
    // fake endpoints: a prediction local to the cursor's file, a context
    // refresh driven by a search tool call, and a prediction targeting a
    // different file (surfaced as a jump until that buffer is open).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // Edits in the cursor's own file show up as a Local prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        // Answer the planning request with a single search tool call.
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // While buffer1 is the active buffer, a prediction for 2.txt is
        // surfaced as a Jump to that file.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once the target buffer is open, the same prediction is Local to it.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3083
    // Single request/response round trip: the fake endpoint returns a unified
    // diff, which must be parsed into the corresponding buffer edit.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff above boils down to one insertion at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3147
    // Edits made to a registered buffer are recorded as events and must be
    // reflected in the prompt as a unified diff of the edit history.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the now-tracked buffer so an edit event is recorded.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above must appear in the outgoing prompt.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3221
    // A diff that reproduces the buffer verbatim yields no edits: the
    // prediction must be dropped and auto-reported with the `Empty` reason.
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Replaces "How" with "How": syntactically a diff, semantically a
        // no-op.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How
            Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3285
    // If the user types the predicted text before the response arrives,
    // interpolating the prediction against the new buffer leaves no edits:
    // it must be dropped and reported with the `InterpolatedEmpty` reason.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Apply the predicted edit manually before the response lands.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3344
// Fixture diff shared by several tests below: rewrites the middle line of
// root/foo.md from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
     Hello!
    -How
    +How are you?
     Bye
"};
3354
// Verifies that a newly completed prediction of equal quality replaces the
// current one, and the displaced prediction is reported as rejected with
// `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor placed at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction becomes current.
    zeta.read_with(cx, |zeta, cx| {
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // second replaces first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
3434
// Verifies that when a later response is worse than the currently displayed
// prediction, the current one is kept and the newcomer is reported as
// rejected with `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor placed at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // The first prediction becomes current.
    zeta.read_with(cx, |zeta, cx| {
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
         Hello!
        -How
        +How are
         Bye
    "});
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // first is preferred over second
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
3523
// Verifies that when two requests are in flight and the later one completes
// first, the earlier request is cancelled: its response does not replace the
// current prediction and it is reported as rejected with `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor placed at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is second
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the earlier (cancelled) request respond.
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
3612
// Verifies that with two requests already pending, starting a third cancels
// the second (most recent pending) request; the first response is still
// used, the cancelled second response is ignored, the third replaces the
// first, and the reject report batches `Canceled` + `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    // Cursor placed at the end of "How" on line 1.
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // start two refresh tasks
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    zeta.update(cx, |zeta, cx| {
        // start a third request
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        // (index 1 identifies the second request in `cancelled_predictions`)
        assert_eq!(
            zeta.get_or_init_zeta_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (_, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Deliver the cancelled second request's response; it must be ignored.
    let cancelled_response = model_response(SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // third completes and replaces first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    // Both outcomes arrive in one batched reject request.
    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
3744
// Exercises how rejection reports are flushed to the server: debounced
// batching, immediate send when the batch limit is reached, and retry of
// failed sends together with subsequently queued rejections.
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);

    zeta.update(cx, |zeta, _cx| {
        zeta.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        zeta.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    // Nothing is sent until the debounce elapses.
    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    zeta.update(cx, |zeta, _cx| {
        for i in 0..70 {
            zeta.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    // NOTE(review): the literal 50 below presumably equals
    // MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 — confirm and consider
    // deriving it from the constant so the test tracks the limit.
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    // NOTE(review): fixed 15s here rather than REJECT_REQUEST_DEBOUNCE —
    // confirm this is intentionally "comfortably past the debounce".
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // The 20 leftover rejections (70 queued - 50 already sent).
    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure
    zeta.update(cx, |zeta, _cx| {
        zeta.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure
    // (dropping the responder makes the fake server's oneshot resolve to an
    // error, so the client treats the send as failed)
    drop(_respond_tx);

    // Add another rejection
    zeta.update(cx, |zeta, _cx| {
        zeta.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
3856
3857 // Skipped until we start including diagnostics in prompt
3858 // #[gpui::test]
3859 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3860 // let (zeta, mut req_rx) = init_test(cx);
3861 // let fs = FakeFs::new(cx.executor());
3862 // fs.insert_tree(
3863 // "/root",
3864 // json!({
3865 // "foo.md": "Hello!\nBye"
3866 // }),
3867 // )
3868 // .await;
3869 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3870
3871 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3872 // let diagnostic = lsp::Diagnostic {
3873 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3874 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3875 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3876 // ..Default::default()
3877 // };
3878
3879 // project.update(cx, |project, cx| {
3880 // project.lsp_store().update(cx, |lsp_store, cx| {
3881 // // Create some diagnostics
3882 // lsp_store
3883 // .update_diagnostics(
3884 // LanguageServerId(0),
3885 // lsp::PublishDiagnosticsParams {
3886 // uri: path_to_buffer_uri.clone(),
3887 // diagnostics: vec![diagnostic],
3888 // version: None,
3889 // },
3890 // None,
3891 // language::DiagnosticSourceKind::Pushed,
3892 // &[],
3893 // cx,
3894 // )
3895 // .unwrap();
3896 // });
3897 // });
3898
3899 // let buffer = project
3900 // .update(cx, |project, cx| {
3901 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3902 // project.open_buffer(path, cx)
3903 // })
3904 // .await
3905 // .unwrap();
3906
3907 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3908 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3909
3910 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3911 // zeta.request_prediction(&project, &buffer, position, cx)
3912 // });
3913
3914 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3915
3916 // assert_eq!(request.diagnostic_groups.len(), 1);
3917 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3918 // .unwrap();
3919 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3920 // assert_eq!(
3921 // value,
3922 // json!({
3923 // "entries": [{
3924 // "range": {
3925 // "start": 8,
3926 // "end": 10
3927 // },
3928 // "diagnostic": {
3929 // "source": null,
3930 // "code": null,
3931 // "code_description": null,
3932 // "severity": 1,
3933 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3934 // "markdown": null,
3935 // "group_id": 0,
3936 // "is_primary": true,
3937 // "is_disk_based": false,
3938 // "is_unnecessary": false,
3939 // "source_kind": "Pushed",
3940 // "data": null,
3941 // "underline": true
3942 // }
3943 // }],
3944 // "primary_ix": 0
3945 // })
3946 // );
3947 // }
3948
3949 fn model_response(text: &str) -> open_ai::Response {
3950 open_ai::Response {
3951 id: Uuid::new_v4().to_string(),
3952 object: "response".into(),
3953 created: 0,
3954 model: "model".into(),
3955 choices: vec![open_ai::Choice {
3956 index: 0,
3957 message: open_ai::RequestMessage::Assistant {
3958 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3959 tool_calls: vec![],
3960 },
3961 finish_reason: None,
3962 }],
3963 usage: Usage {
3964 prompt_tokens: 0,
3965 completion_tokens: 0,
3966 total_tokens: 0,
3967 },
3968 }
3969 }
3970
3971 fn prompt_from_request(request: &open_ai::Request) -> &str {
3972 assert_eq!(request.messages.len(), 1);
3973 let open_ai::RequestMessage::User {
3974 content: open_ai::MessageContent::Plain(content),
3975 ..
3976 } = &request.messages[0]
3977 else {
3978 panic!(
3979 "Request does not have single user message of type Plain. {:#?}",
3980 request
3981 );
3982 };
3983 content
3984 }
3985
/// Receiving ends of the fake HTTP endpoints wired up by `init_test`. Each
/// received request is paired with a oneshot sender the test uses to deliver
/// (or, by dropping it, fail) the response.
struct RequestChannels {
    // Requests hitting "/predict_edits/raw".
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Requests hitting "/predict_edits/reject".
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
3990
/// Shared test setup: installs test settings, stands up a fake HTTP client
/// that forwards predict/reject requests to in-memory channels, authenticates
/// a fake cloud client, and returns the global `Zeta` entity plus the request
/// channels the tests drive responses through.
fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh: always hand back a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction requests: forward to the test and await
                        // the response it sends back on the oneshot.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            // If the test drops res_tx, this `?` turns into a
                            // request failure.
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection reports: same forwarding scheme.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // Pretend we're signed in so token/auth paths don't short-circuit.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = Zeta::global(&client, &user_store, cx);

        (
            zeta,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
4057}