1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
9 ZED_VERSION_HEADER_NAME,
10};
11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
13use collections::{HashMap, HashSet};
14use command_palette_hooks::CommandPaletteFilter;
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{
17 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
18 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
19 SyntaxIndex, SyntaxIndexState,
20};
21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
22use futures::channel::mpsc::UnboundedReceiver;
23use futures::channel::{mpsc, oneshot};
24use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
25use gpui::BackgroundExecutor;
26use gpui::{
27 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
28 http_client::{self, AsyncBody, Method},
29 prelude::*,
30};
31use language::{
32 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
33};
34use language::{BufferSnapshot, OffsetRangeExt};
35use language_model::{LlmApiToken, RefreshLlmTokenListener};
36use open_ai::FunctionDefinition;
37use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
38use release_channel::AppVersion;
39use semver::Version;
40use serde::de::DeserializeOwned;
41use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
42use std::any::{Any as _, TypeId};
43use std::collections::{VecDeque, hash_map};
44use telemetry_events::EditPredictionRating;
45use workspace::Workspace;
46
47use std::ops::Range;
48use std::path::Path;
49use std::rc::Rc;
50use std::str::FromStr as _;
51use std::sync::{Arc, LazyLock};
52use std::time::{Duration, Instant};
53use std::{env, mem};
54use thiserror::Error;
55use util::rel_path::RelPathBuf;
56use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
57use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
58
59pub mod assemble_excerpts;
60mod license_detection;
61mod onboarding_modal;
62mod prediction;
63mod provider;
64mod rate_prediction_modal;
65pub mod retrieval_search;
66pub mod sweep_ai;
67pub mod udiff;
68mod xml_edits;
69pub mod zeta1;
70
71#[cfg(test)]
72mod zeta_tests;
73
74use crate::assemble_excerpts::assemble_excerpts;
75use crate::license_detection::LicenseDetectionWatcher;
76use crate::onboarding_modal::ZedPredictModal;
77pub use crate::prediction::EditPrediction;
78pub use crate::prediction::EditPredictionId;
79pub use crate::prediction::EditPredictionInputs;
80use crate::prediction::EditPredictionResult;
81use crate::rate_prediction_modal::{
82 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
83 ThumbsUpActivePrediction,
84};
85pub use crate::sweep_ai::SweepAi;
86use crate::zeta1::request_prediction_with_zeta1;
87pub use provider::ZetaEditPredictionProvider;
88
// gpui actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
100
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits whose rows differ by at most this much are coalesced
/// into a single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key persisting the user's data collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long to wait for more rejections before flushing a batch to the server.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
106
107pub struct SweepFeatureFlag;
108
109impl FeatureFlag for SweepFeatureFlag {
110 const NAME: &str = "sweep-ai";
111}
/// Default sizing for the excerpt extracted around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for half the excerpt bytes to fall before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};
117
/// Default context mode: agentic retrieval with the default excerpt options.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Default options for agentic context retrieval.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
124
/// Default options for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
134
/// Options a freshly-created [`Zeta`] starts with (see [`Zeta::new`]).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
143
/// Set `ZED_ZETA2_OLLAMA` to any non-empty value to route requests to a
/// local Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for context retrieval; overridable via `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit predictions; overridable via `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Prediction endpoint override: `ZED_PREDICT_EDITS_URL` if set, else the
/// local Ollama endpoint when `USE_OLLAMA`, else `None` (use Zed's LLM service).
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
171
/// Feature flag gating the Zeta2 edit prediction backend.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff always have Zeta2 enabled, regardless of remote flag state.
    fn enabled_for_staff() -> bool {
        true
    }
}
181
/// gpui global holding the process-wide [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
186
/// Coordinates edit predictions: tracks per-project state, talks to the
/// prediction backends, and records shown/rated/rejected predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when an API response demands a newer Zed
    // version (cf. MINIMUM_REQUIRED_VERSION_HEADER_NAME) — set outside this chunk.
    update_required: bool,
    /// When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    /// Rejections queued here are batched and sent by `handle_rejected_predictions`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Recently shown predictions, newest first; bounded in `did_show_current_prediction`.
    shown_predictions: VecDeque<EditPrediction>,
    /// Ids of predictions the user has rated.
    rated_predictions: HashSet<EditPredictionId>,
}
205
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
213
/// Tunable knobs for context gathering and prompt assembly.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for a prediction request is gathered.
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}
223
/// Strategy used to gather context around the cursor for a prediction.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Agentic context retrieval.
    Agentic(AgenticContextOptions),
    /// Local retrieval backed by a [`SyntaxIndex`].
    Syntax(EditPredictionContextOptions),
}
229
/// Options for agentic context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
234
235impl ContextMode {
236 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
237 match self {
238 ContextMode::Agentic(options) => &options.excerpt,
239 ContextMode::Syntax(options) => &options.excerpt,
240 }
241 }
242}
243
/// Debug events streamed over the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the context retrieval model.
    pub search_prompt: String,
}

/// Payload for context retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// Time spent on context retrieval before issuing the request.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
284
/// Per-project prediction state owned by [`Zeta`].
struct ZetaProject {
    /// Present only when using [`ContextMode::Syntax`].
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, oldest first (bounded by `EVENT_COUNT_MAX`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The still-open most recent change; further nearby edits coalesce into it.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Entity and time of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled before resolving.
    cancelled_predictions: HashSet<usize>,
    /// Retrieved context ranges per buffer, once retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree, used to tag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
303
impl ZetaProject {
    /// Returns all finalized events plus the still-open `last_event`
    /// (finalized on the fly), oldest first.
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks a pending prediction as cancelled and, once its request task
    /// resolves, reports the resulting prediction as rejected with `Canceled`.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // The task may resolve without producing a prediction; then
            // there is nothing to reject.
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }
}
337
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has actually been rendered to the user.
    pub was_shown: bool,
}
344
impl CurrentEditPrediction {
    /// Decides whether `self` (a newly arrived prediction) should replace
    /// `old_prediction` as the one shown to the user.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // content, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer interpolates, always replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace a single-edit prediction when the new edit extends
            // the old edit's text at the same range.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
383
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update rather than typing in a buffer.
    DiagnosticsUpdate,
    /// Typing in the buffer with this entity id.
    Buffer(EntityId),
}
389
390impl PredictionRequestedBy {
391 pub fn buffer_id(&self) -> Option<EntityId> {
392 match self {
393 PredictionRequestedBy::DiagnosticsUpdate => None,
394 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
395 }
396 }
397}
398
/// An in-flight prediction request; `task` resolves with the produced
/// prediction's id, if any.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits another buffer; this buffer offers a jump to it.
    Jump { prediction: &'a EditPrediction },
}
411
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Both variants carry a prediction; deref to it uniformly.
    fn deref(&self) -> &Self::Target {
        let (BufferEditPrediction::Local { prediction }
        | BufferEditPrediction::Jump { prediction }) = self;
        prediction
    }
}
423
/// Tracking state for a buffer whose edits feed the change-event history.
struct RegisteredBuffer {
    /// Last snapshot recorded by `report_changes_for_buffer`.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent, still-coalescing buffer change.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// End of the last edit, used to decide whether the next edit is close
    /// enough (by row) to coalesce into this event.
    end_edit_anchor: Option<Anchor>,
}
434
impl LastEvent {
    /// Converts the accumulated change into a `BufferChange` event, or `None`
    /// when nothing effectively changed (same path and an empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Both the old and new file must belong to worktrees detected as open
        // source for the event to be flagged as such.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
471
472fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
473 if let Some(file) = snapshot.file() {
474 file.full_path(cx).into()
475 } else {
476 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
477 }
478}
479
480impl Zeta {
481 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
482 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
483 }
484
485 pub fn global(
486 client: &Arc<Client>,
487 user_store: &Entity<UserStore>,
488 cx: &mut App,
489 ) -> Entity<Self> {
490 cx.try_global::<ZetaGlobal>()
491 .map(|global| global.0.clone())
492 .unwrap_or_else(|| {
493 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
494 cx.set_global(ZetaGlobal(zeta.clone()));
495 zeta
496 })
497 }
498
    /// Creates a `Zeta`, spawning the background task that batches and sends
    /// prediction-rejection reports and subscribing to LLM token refreshes.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections queued on `reject_tx` are flushed in batches by the
        // background task spawned below.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token,
            // Re-fetch the LLM token whenever a refresh event is broadcast.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
554
    /// Selects which backend model serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
558
559 pub fn has_sweep_api_token(&self) -> bool {
560 self.sweep_ai
561 .api_token
562 .clone()
563 .now_or_never()
564 .flatten()
565 .is_some()
566 }
567
    /// Installs an eval cache (eval-support builds only).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
572
573 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
574 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
575 self.debug_tx = Some(debug_watch_tx);
576 debug_watch_rx
577 }
578
    /// The currently active options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the active options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
586
587 pub fn clear_history(&mut self) {
588 for zeta_project in self.projects.values_mut() {
589 zeta_project.events.clear();
590 }
591 }
592
593 pub fn context_for_project(
594 &self,
595 project: &Entity<Project>,
596 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
597 self.projects
598 .get(&project.entity_id())
599 .and_then(|project| {
600 Some(
601 project
602 .context
603 .as_ref()?
604 .iter()
605 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
606 )
607 })
608 .into_iter()
609 .flatten()
610 }
611
612 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
613 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
614 self.user_store.read(cx).edit_prediction_usage()
615 } else {
616 None
617 }
618 }
619
    /// Ensures per-project prediction state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
623
    /// Starts tracking `buffer` within `project`, creating the per-project
    /// state first if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
633
    /// Returns the state for `project`, creating it on first access (including
    /// a syntax index when running in syntax context mode).
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // A syntax index is only needed for syntax-based retrieval.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
666
    /// Reacts to project events: maintains the recent-paths MRU list and, when
    /// the Zeta2 flag is on, refreshes predictions after diagnostics updates.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Deduplicate, then move the path to the front (most recent).
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
699
    /// Registers `buffer` with `zeta_project` (idempotent per buffer):
    /// installs a license-detection watcher for the buffer's worktree, plus
    /// subscriptions that feed edits into the change history and deregister
    /// the buffer when it is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license watcher per worktree; it is removed again
        // when the worktree entity is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record each edit into the change-event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Deregister when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
759
    /// Records a buffer change into the project's event history. A change that
    /// directly continues the open `last_event` in the same buffer and lands
    /// within `CHANGE_GROUPING_LINE_SPAN` rows of its last edit is coalesced
    /// into it; otherwise the open event is finalized and a new one started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Where the most recent edit ended; used for proximity-based grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // The change continues the open event only if it starts exactly
            // where that event's snapshot left off...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the new edit is within CHANGE_GROUPING_LINE_SPAN rows of
            // the previous edit's end.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Bound the history, dropping the oldest event to make room.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previously open event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
822
    /// Returns the project's current prediction from `buffer`'s perspective:
    /// `Local` when the prediction targets this buffer, `Jump` when this
    /// buffer should offer a jump to the prediction, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only the buffer that requested the prediction shows a jump;
            // diagnostics-triggered predictions show it from any buffer.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
854
    /// Accepts the current prediction: clears it, cancels any still-pending
    /// requests, and reports the acceptance to the server. No-op for Sweep.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Pending requests are now stale; cancelling reports them as rejected
        // once they resolve.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (for development).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
907
    /// Background loop that reports rejected predictions in batches.
    ///
    /// Rejections accumulate until half the per-request limit is reached, or
    /// until the debounce timer fires with no further rejection immediately
    /// available. Batched entries are only drained after a successful send, so
    /// a failed batch is retried (capped at the per-request limit, most recent
    /// entries first) on the next flush.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                // Wait for either another queued rejection (keep batching) or
                // the debounce timeout (flush what we have).
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            // Send the most recent `flush_count` entries.
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // On failure, keep the entries so they are retried next flush.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
965
966 fn reject_current_prediction(
967 &mut self,
968 reason: EditPredictionRejectReason,
969 project: &Entity<Project>,
970 ) {
971 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
972 project_state.pending_predictions.clear();
973 if let Some(prediction) = project_state.current_prediction.take() {
974 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
975 }
976 };
977 }
978
979 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
980 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
981 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
982 if !current_prediction.was_shown {
983 current_prediction.was_shown = true;
984 self.shown_predictions
985 .push_front(current_prediction.prediction.clone());
986 if self.shown_predictions.len() > 50 {
987 let completion = self.shown_predictions.pop_back().unwrap();
988 self.rated_predictions.remove(&completion.id);
989 }
990 }
991 }
992 }
993 }
994
    /// Queues a rejection report for `prediction_id` on the background
    /// channel (flushed in batches). Rejections are not reported for Sweep.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
    ) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        self.reject_predictions_tx
            .unbounded_send(EditPredictionRejection {
                request_id: prediction_id.to_string(),
                reason,
                was_shown,
            })
            .log_err();
    }
1014
1015 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1016 self.projects
1017 .get(&project.entity_id())
1018 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1019 }
1020
    /// Queues a throttled prediction refresh triggered from `buffer` at
    /// `position`; the resulting prediction is tagged as buffer-requested.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                // The Zeta entity is gone; nothing to do.
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1056
    /// Queues a prediction refresh triggered by a diagnostics update: finds
    /// the next diagnostic location starting from the active buffer and
    /// requests a prediction there. Buffer-initiated predictions are
    /// preferred — this is a no-op while one is current, and the result is
    /// rejected if one appears before this request resolves.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                // No active entry (or the project is gone); nothing to refresh.
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-initiated prediction may have arrived in
                        // the meantime; if so, reject this one in its favor.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1137
    /// Minimum interval between consecutive prediction refreshes for the same entity.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Tests skip throttling so queued refreshes resolve immediately.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1142
1143 fn queue_prediction_refresh(
1144 &mut self,
1145 project: Entity<Project>,
1146 throttle_entity: EntityId,
1147 cx: &mut Context<Self>,
1148 do_refresh: impl FnOnce(
1149 WeakEntity<Self>,
1150 &mut AsyncApp,
1151 )
1152 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1153 + 'static,
1154 ) {
1155 let zeta_project = self.get_or_init_zeta_project(&project, cx);
1156 let pending_prediction_id = zeta_project.next_pending_prediction_id;
1157 zeta_project.next_pending_prediction_id += 1;
1158 let last_request = zeta_project.last_prediction_refresh;
1159
1160 let task = cx.spawn(async move |this, cx| {
1161 if let Some((last_entity, last_timestamp)) = last_request
1162 && throttle_entity == last_entity
1163 && let Some(timeout) =
1164 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1165 {
1166 cx.background_executor().timer(timeout).await;
1167 }
1168
1169 // If this task was cancelled before the throttle timeout expired,
1170 // do not perform a request.
1171 let mut is_cancelled = true;
1172 this.update(cx, |this, cx| {
1173 let project_state = this.get_or_init_zeta_project(&project, cx);
1174 if !project_state
1175 .cancelled_predictions
1176 .remove(&pending_prediction_id)
1177 {
1178 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1179 is_cancelled = false;
1180 }
1181 })
1182 .ok();
1183 if is_cancelled {
1184 return None;
1185 }
1186
1187 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1188 let new_prediction_id = new_prediction_result
1189 .as_ref()
1190 .map(|(prediction, _)| prediction.id.clone());
1191
1192 // When a prediction completes, remove it from the pending list, and cancel
1193 // any pending predictions that were enqueued before it.
1194 this.update(cx, |this, cx| {
1195 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1196
1197 let is_cancelled = zeta_project
1198 .cancelled_predictions
1199 .remove(&pending_prediction_id);
1200
1201 let new_current_prediction = if !is_cancelled
1202 && let Some((prediction_result, requested_by)) = new_prediction_result
1203 {
1204 match prediction_result.prediction {
1205 Ok(prediction) => {
1206 let new_prediction = CurrentEditPrediction {
1207 requested_by,
1208 prediction,
1209 was_shown: false,
1210 };
1211
1212 if let Some(current_prediction) =
1213 zeta_project.current_prediction.as_ref()
1214 {
1215 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1216 {
1217 this.reject_current_prediction(
1218 EditPredictionRejectReason::Replaced,
1219 &project,
1220 );
1221
1222 Some(new_prediction)
1223 } else {
1224 this.reject_prediction(
1225 new_prediction.prediction.id,
1226 EditPredictionRejectReason::CurrentPreferred,
1227 false,
1228 );
1229 None
1230 }
1231 } else {
1232 Some(new_prediction)
1233 }
1234 }
1235 Err(reject_reason) => {
1236 this.reject_prediction(prediction_result.id, reject_reason, false);
1237 None
1238 }
1239 }
1240 } else {
1241 None
1242 };
1243
1244 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1245
1246 if let Some(new_prediction) = new_current_prediction {
1247 zeta_project.current_prediction = Some(new_prediction);
1248 }
1249
1250 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
1251 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1252 if pending_prediction.id == pending_prediction_id {
1253 pending_predictions.remove(ix);
1254 for pending_prediction in pending_predictions.drain(0..ix) {
1255 zeta_project.cancel_pending_prediction(pending_prediction, cx)
1256 }
1257 break;
1258 }
1259 }
1260 this.get_or_init_zeta_project(&project, cx)
1261 .pending_predictions = pending_predictions;
1262 cx.notify();
1263 })
1264 .ok();
1265
1266 new_prediction_id
1267 });
1268
1269 if zeta_project.pending_predictions.len() <= 1 {
1270 zeta_project.pending_predictions.push(PendingPrediction {
1271 id: pending_prediction_id,
1272 task,
1273 });
1274 } else if zeta_project.pending_predictions.len() == 2 {
1275 let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
1276 zeta_project.pending_predictions.push(PendingPrediction {
1277 id: pending_prediction_id,
1278 task,
1279 });
1280 zeta_project.cancel_pending_prediction(pending_prediction, cx);
1281 }
1282 }
1283
1284 pub fn request_prediction(
1285 &mut self,
1286 project: &Entity<Project>,
1287 active_buffer: &Entity<Buffer>,
1288 position: language::Anchor,
1289 trigger: PredictEditsRequestTrigger,
1290 cx: &mut Context<Self>,
1291 ) -> Task<Result<Option<EditPredictionResult>>> {
1292 self.request_prediction_internal(
1293 project.clone(),
1294 active_buffer.clone(),
1295 position,
1296 trigger,
1297 cx.has_flag::<Zeta2FeatureFlag>(),
1298 cx,
1299 )
1300 }
1301
    /// Dispatches a prediction request to the configured backend (Zeta1,
    /// Zeta2, or Sweep). When `allow_jump` is true and the backend returned no
    /// prediction, retries once at the nearest diagnostic outside the range
    /// that was already covered by this request.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above/below the cursor considered "nearby" when searching diagnostics.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there is edit history to base a prediction on.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry at the jump location with jumping disabled, so we
                    // recurse at most once.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1393
    /// Finds a diagnostic to jump to for a follow-up prediction request.
    ///
    /// Prefers the diagnostic in the active buffer closest to the cursor (by
    /// row distance) that falls outside `active_buffer_diagnostic_search_range`,
    /// i.e. one that was not already covered by the previous request. If none
    /// exists, falls back to another buffer with diagnostics whose path shares
    /// the most leading directory components with the active buffer's path.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Nothing in the active buffer: look for the "closest" other buffer
        // that has diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // NOTE(review): picks the minimum severity value —
                            // presumably the most severe diagnostic (errors sort
                            // first); confirm the ordering of `severity`.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1472
    /// Requests an edit prediction from the Zeta2 model.
    ///
    /// Gathers context according to `options.context` (agentic included files
    /// or syntax-based context), builds a prompt, sends it to the LLM
    /// endpoint, and parses the model's output into edits against one of the
    /// included buffers.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let file = active_buffer.read(cx).file();
        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Snapshot every context buffer along with its path and ranges.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering: by path, then number of ranges.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        // Context assembly, prompt construction, and the network request all
        // run off the main thread.
        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                // Build the request payload according to the configured context mode.
                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's entry
                        // (moving that entry to the end), or append a new entry.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line ranges and assemble the
                        // excerpt payloads for the request.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                            trigger,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                            trigger,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // Notify the debug view of the outgoing request, keeping a
                // channel to report the response back to it later.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch for working on the request pipeline
                // without hitting the network.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                // A response without text is reported as an id with no edits.
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((Some((request_id, None)), usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolve a model-reported path to one of the included buffers.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output into (snapshot, edits), per prompt format.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        Some((
                            inputs,
                            edited_buffer,
                            edited_buffer_snapshot.clone(),
                            edits,
                            received_response_at,
                        )),
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, prediction)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // No edits in the response: record the id with an `Empty` rejection.
                let Some((
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = prediction
                else {
                    return Ok(Some(EditPredictionResult {
                        id,
                        prediction: Err(EditPredictionRejectReason::Empty),
                    }));
                };

                // TODO telemetry: duration, etc
                Ok(Some(
                    EditPredictionResult::new(
                        id,
                        &edited_buffer,
                        &edited_buffer_snapshot,
                        edits.into(),
                        buffer_snapshotted_at,
                        received_response_at,
                        inputs,
                        cx,
                    )
                    .await,
                ))
            }
        })
    }
1891
    /// Sends an OpenAI-style request to the edit-prediction endpoint, returning
    /// the parsed response plus any usage info reported via response headers.
    ///
    /// The endpoint is `PREDICT_EDITS_URL` when that override is set, otherwise
    /// the Zed LLM `/predict_edits/raw` route. With the `eval-support` feature,
    /// responses are cached keyed by a hash of the URL and serialized request.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            // Serve from the cache when this exact request was seen before.
            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the response for future cache hits.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1949
1950 fn handle_api_response<T>(
1951 this: &WeakEntity<Self>,
1952 response: Result<(T, Option<EditPredictionUsage>)>,
1953 cx: &mut gpui::AsyncApp,
1954 ) -> Result<T> {
1955 match response {
1956 Ok((data, usage)) => {
1957 if let Some(usage) = usage {
1958 this.update(cx, |this, cx| {
1959 this.user_store.update(cx, |user_store, cx| {
1960 user_store.update_edit_prediction_usage(usage, cx);
1961 });
1962 })
1963 .ok();
1964 }
1965 Ok(data)
1966 }
1967 Err(err) => {
1968 if err.is::<ZedUpdateRequiredError>() {
1969 cx.update(|cx| {
1970 this.update(cx, |this, _cx| {
1971 this.update_required = true;
1972 })
1973 .ok();
1974
1975 let error_message: SharedString = err.to_string().into();
1976 show_app_notification(
1977 NotificationId::unique::<ZedUpdateRequiredError>(),
1978 cx,
1979 move |cx| {
1980 cx.new(|cx| {
1981 ErrorMessagePrompt::new(error_message.clone(), cx)
1982 .with_link_button("Update Zed", "https://zed.dev/releases")
1983 })
1984 },
1985 );
1986 })
1987 .ok();
1988 }
1989 Err(err)
1990 }
1991 }
1992 }
1993
    /// POSTs a JSON request produced by `build`, attaching the LLM bearer
    /// token and the Zed version header, and deserializes the response body.
    ///
    /// Fails with `ZedUpdateRequiredError` when the server reports a minimum
    /// required version newer than `app_version`. If the server reports an
    /// expired LLM token, refreshes the token and retries exactly once.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loop exists only to allow a single retry after a token refresh.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version regardless of status.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh it and go around once more.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2057
    /// How long without edits counts as "idle" for context retrieval; the
    /// first edit after an idle period refreshes context immediately.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refreshes while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2060
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to Zeta2 in agentic context mode.
        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
            return;
        }

        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // "Idle" means no edit seen within CONTEXT_RETRIEVAL_IDLE_DURATION
        // (a missing timestamp counts as idle).
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the debounce task drops any previously scheduled refresh.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // First edit after an idle period: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Otherwise wait for a pause in editing before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Keep the task alive on the project state so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2112
2113 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2114 // and avoid spawning more than one concurrent task.
2115 pub fn refresh_context(
2116 &mut self,
2117 project: Entity<Project>,
2118 buffer: Entity<language::Buffer>,
2119 cursor_position: language::Anchor,
2120 cx: &mut Context<Self>,
2121 ) -> Task<Result<()>> {
2122 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2123 return Task::ready(anyhow::Ok(()));
2124 };
2125
2126 let ContextMode::Agentic(options) = &self.options().context else {
2127 return Task::ready(anyhow::Ok(()));
2128 };
2129
2130 let snapshot = buffer.read(cx).snapshot();
2131 let cursor_point = cursor_position.to_point(&snapshot);
2132 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
2133 cursor_point,
2134 &snapshot,
2135 &options.excerpt,
2136 None,
2137 ) else {
2138 return Task::ready(Ok(()));
2139 };
2140
2141 let app_version = AppVersion::global(cx);
2142 let client = self.client.clone();
2143 let llm_token = self.llm_token.clone();
2144 let debug_tx = self.debug_tx.clone();
2145 let current_file_path: Arc<Path> = snapshot
2146 .file()
2147 .map(|f| f.full_path(cx).into())
2148 .unwrap_or_else(|| Path::new("untitled").into());
2149
2150 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
2151 predict_edits_v3::PlanContextRetrievalRequest {
2152 excerpt: cursor_excerpt.text(&snapshot).body,
2153 excerpt_path: current_file_path,
2154 excerpt_line_range: cursor_excerpt.line_range,
2155 cursor_file_max_row: Line(snapshot.max_point().row),
2156 events: zeta_project.events(cx),
2157 },
2158 ) {
2159 Ok(prompt) => prompt,
2160 Err(err) => {
2161 return Task::ready(Err(err));
2162 }
2163 };
2164
2165 if let Some(debug_tx) = &debug_tx {
2166 debug_tx
2167 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
2168 ZetaContextRetrievalStartedDebugInfo {
2169 project: project.clone(),
2170 timestamp: Instant::now(),
2171 search_prompt: prompt.clone(),
2172 },
2173 ))
2174 .ok();
2175 }
2176
2177 pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
2178 let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
2179 language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
2180 );
2181
2182 let description = schema
2183 .get("description")
2184 .and_then(|description| description.as_str())
2185 .unwrap()
2186 .to_string();
2187
2188 (schema.into(), description)
2189 });
2190
2191 let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
2192
2193 let request = open_ai::Request {
2194 model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
2195 messages: vec![open_ai::RequestMessage::User {
2196 content: open_ai::MessageContent::Plain(prompt),
2197 }],
2198 stream: false,
2199 max_completion_tokens: None,
2200 stop: Default::default(),
2201 temperature: 0.7,
2202 tool_choice: None,
2203 parallel_tool_calls: None,
2204 tools: vec![open_ai::ToolDefinition::Function {
2205 function: FunctionDefinition {
2206 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
2207 description: Some(tool_description),
2208 parameters: Some(tool_schema),
2209 },
2210 }],
2211 prompt_cache_key: None,
2212 reasoning_effort: None,
2213 };
2214
2215 #[cfg(feature = "eval-support")]
2216 let eval_cache = self.eval_cache.clone();
2217
2218 cx.spawn(async move |this, cx| {
2219 log::trace!("Sending search planning request");
2220 let response = Self::send_raw_llm_request(
2221 request,
2222 client,
2223 llm_token,
2224 app_version,
2225 #[cfg(feature = "eval-support")]
2226 eval_cache.clone(),
2227 #[cfg(feature = "eval-support")]
2228 EvalCacheEntryKind::Context,
2229 )
2230 .await;
2231 let mut response = Self::handle_api_response(&this, response, cx)?;
2232 log::trace!("Got search planning response");
2233
2234 let choice = response
2235 .choices
2236 .pop()
2237 .context("No choices in retrieval response")?;
2238 let open_ai::RequestMessage::Assistant {
2239 content: _,
2240 tool_calls,
2241 } = choice.message
2242 else {
2243 anyhow::bail!("Retrieval response didn't include an assistant message");
2244 };
2245
2246 let mut queries: Vec<SearchToolQuery> = Vec::new();
2247 for tool_call in tool_calls {
2248 let open_ai::ToolCallContent::Function { function } = tool_call.content;
2249 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
2250 log::warn!(
2251 "Context retrieval response tried to call an unknown tool: {}",
2252 function.name
2253 );
2254
2255 continue;
2256 }
2257
2258 let input: SearchToolInput = serde_json::from_str(&function.arguments)
2259 .with_context(|| format!("invalid search json {}", &function.arguments))?;
2260 queries.extend(input.queries);
2261 }
2262
2263 if let Some(debug_tx) = &debug_tx {
2264 debug_tx
2265 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
2266 ZetaSearchQueryDebugInfo {
2267 project: project.clone(),
2268 timestamp: Instant::now(),
2269 search_queries: queries.clone(),
2270 },
2271 ))
2272 .ok();
2273 }
2274
2275 log::trace!("Running retrieval search: {queries:#?}");
2276
2277 let related_excerpts_result = retrieval_search::run_retrieval_searches(
2278 queries,
2279 project.clone(),
2280 #[cfg(feature = "eval-support")]
2281 eval_cache,
2282 cx,
2283 )
2284 .await;
2285
2286 log::trace!("Search queries executed");
2287
2288 if let Some(debug_tx) = &debug_tx {
2289 debug_tx
2290 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
2291 ZetaContextRetrievalDebugInfo {
2292 project: project.clone(),
2293 timestamp: Instant::now(),
2294 },
2295 ))
2296 .ok();
2297 }
2298
2299 this.update(cx, |this, _cx| {
2300 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
2301 return Ok(());
2302 };
2303 zeta_project.refresh_context_task.take();
2304 if let Some(debug_tx) = &this.debug_tx {
2305 debug_tx
2306 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
2307 ZetaContextRetrievalDebugInfo {
2308 project,
2309 timestamp: Instant::now(),
2310 },
2311 ))
2312 .ok();
2313 }
2314 match related_excerpts_result {
2315 Ok(excerpts) => {
2316 zeta_project.context = Some(excerpts);
2317 Ok(())
2318 }
2319 Err(error) => Err(error),
2320 }
2321 })?
2322 })
2323 }
2324
2325 pub fn set_context(
2326 &mut self,
2327 project: Entity<Project>,
2328 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2329 ) {
2330 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2331 zeta_project.context = Some(context);
2332 }
2333 }
2334
    /// Collects diagnostic groups from every language server, orders them by
    /// proximity to the cursor, and serializes them to JSON until a byte
    /// budget is reached.
    ///
    /// Returns the serialized groups together with a flag that is `true` when
    /// one or more groups were dropped because `max_diagnostics_bytes` was
    /// exceeded.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            // Distance is measured from the group's primary entry.
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                // Cursor is inside the range. NOTE(review): this ranks by
                // distance to the nearest edge rather than zero — confirm
                // that is intended.
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            // Serializing a resolved diagnostic group is treated as infallible.
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            // Stop at the first group that would overflow the budget; since
            // groups are sorted by proximity, the nearest ones are kept.
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }
2380
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI from the current
    /// buffer state, gathering syntax-based context around `position` on a
    /// background task.
    ///
    /// Returns an error if the buffer has no file path. Panics when the
    /// configured context mode is `Agentic`, which this CLI path does not
    /// support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index's shared state, if this project has one.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when resolvable.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Take an owned lock on the index state for the duration of
            // context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            // Debug info is always requested on the CLI path.
            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                    PredictEditsRequestTrigger::Other,
                )
            })
        })
    }
2459
2460 pub fn wait_for_initial_indexing(
2461 &mut self,
2462 project: &Entity<Project>,
2463 cx: &mut Context<Self>,
2464 ) -> Task<Result<()>> {
2465 let zeta_project = self.get_or_init_zeta_project(project, cx);
2466 if let Some(syntax_index) = &zeta_project.syntax_index {
2467 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2468 } else {
2469 Task::ready(Ok(()))
2470 }
2471 }
2472
2473 fn is_file_open_source(
2474 &self,
2475 project: &Entity<Project>,
2476 file: &Arc<dyn File>,
2477 cx: &App,
2478 ) -> bool {
2479 if !file.is_local() || file.is_private() {
2480 return false;
2481 }
2482 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2483 return false;
2484 };
2485 zeta_project
2486 .license_detection_watchers
2487 .get(&file.worktree_id(cx))
2488 .as_ref()
2489 .is_some_and(|watcher| watcher.is_project_open_source())
2490 }
2491
2492 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2493 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2494 }
2495
2496 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2497 if !self.data_collection_choice.is_enabled() {
2498 return false;
2499 }
2500 events.iter().all(|event| {
2501 matches!(
2502 event.as_ref(),
2503 Event::BufferChange {
2504 in_open_source_repo: true,
2505 ..
2506 }
2507 )
2508 })
2509 }
2510
2511 fn load_data_collection_choice() -> DataCollectionChoice {
2512 let choice = KEY_VALUE_STORE
2513 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2514 .log_err()
2515 .flatten();
2516
2517 match choice.as_deref() {
2518 Some("true") => DataCollectionChoice::Enabled,
2519 Some("false") => DataCollectionChoice::Disabled,
2520 Some(_) => {
2521 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2522 DataCollectionChoice::NotAnswered
2523 }
2524 None => DataCollectionChoice::NotAnswered,
2525 }
2526 }
2527
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2531
    /// Number of predictions that have been shown to the user.
    // NOTE(review): the name says "completions" but it counts
    // `shown_predictions` — consider renaming for consistency.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2535
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2539
    /// Records a user rating for `prediction`, reports it via telemetry, and
    /// notifies observers.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        // Remember that this prediction was rated so the UI can avoid
        // offering to rate it again.
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away instead of waiting for the periodic telemetry flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2558}
2559
2560pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2561 let choice = res.choices.pop()?;
2562 let output_text = match choice.message {
2563 open_ai::RequestMessage::Assistant {
2564 content: Some(open_ai::MessageContent::Plain(content)),
2565 ..
2566 } => content,
2567 open_ai::RequestMessage::Assistant {
2568 content: Some(open_ai::MessageContent::Multipart(mut content)),
2569 ..
2570 } => {
2571 if content.is_empty() {
2572 log::error!("No output from Baseten completion response");
2573 return None;
2574 }
2575
2576 match content.remove(0) {
2577 open_ai::MessagePart::Text { text } => text,
2578 open_ai::MessagePart::Image { .. } => {
2579 log::error!("Expected text, got an image");
2580 return None;
2581 }
2582 }
2583 }
2584 _ => {
2585 log::error!("Invalid response message: {:?}", choice.message);
2586 return None;
2587 }
2588 };
2589 Some(output_text)
2590}
2591
/// Error indicating the user must upgrade Zed to keep using edit predictions.
// `minimum_version` is presumably parsed from the server's
// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` header — confirm at construction site.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2599
2600fn make_syntax_context_cloud_request(
2601 excerpt_path: Arc<Path>,
2602 context: EditPredictionContext,
2603 events: Vec<Arc<predict_edits_v3::Event>>,
2604 can_collect_data: bool,
2605 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2606 diagnostic_groups_truncated: bool,
2607 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2608 debug_info: bool,
2609 worktrees: &Vec<worktree::Snapshot>,
2610 index_state: Option<&SyntaxIndexState>,
2611 prompt_max_bytes: Option<usize>,
2612 prompt_format: PromptFormat,
2613 trigger: PredictEditsRequestTrigger,
2614) -> predict_edits_v3::PredictEditsRequest {
2615 let mut signatures = Vec::new();
2616 let mut declaration_to_signature_index = HashMap::default();
2617 let mut referenced_declarations = Vec::new();
2618
2619 for snippet in context.declarations {
2620 let project_entry_id = snippet.declaration.project_entry_id();
2621 let Some(path) = worktrees.iter().find_map(|worktree| {
2622 worktree.entry_for_id(project_entry_id).map(|entry| {
2623 let mut full_path = RelPathBuf::new();
2624 full_path.push(worktree.root_name());
2625 full_path.push(&entry.path);
2626 full_path
2627 })
2628 }) else {
2629 continue;
2630 };
2631
2632 let parent_index = index_state.and_then(|index_state| {
2633 snippet.declaration.parent().and_then(|parent| {
2634 add_signature(
2635 parent,
2636 &mut declaration_to_signature_index,
2637 &mut signatures,
2638 index_state,
2639 )
2640 })
2641 });
2642
2643 let (text, text_is_truncated) = snippet.declaration.item_text();
2644 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2645 path: path.as_std_path().into(),
2646 text: text.into(),
2647 range: snippet.declaration.item_line_range(),
2648 text_is_truncated,
2649 signature_range: snippet.declaration.signature_range_in_item_text(),
2650 parent_index,
2651 signature_score: snippet.score(DeclarationStyle::Signature),
2652 declaration_score: snippet.score(DeclarationStyle::Declaration),
2653 score_components: snippet.components,
2654 });
2655 }
2656
2657 let excerpt_parent = index_state.and_then(|index_state| {
2658 context
2659 .excerpt
2660 .parent_declarations
2661 .last()
2662 .and_then(|(parent, _)| {
2663 add_signature(
2664 *parent,
2665 &mut declaration_to_signature_index,
2666 &mut signatures,
2667 index_state,
2668 )
2669 })
2670 });
2671
2672 predict_edits_v3::PredictEditsRequest {
2673 excerpt_path,
2674 excerpt: context.excerpt_text.body,
2675 excerpt_line_range: context.excerpt.line_range,
2676 excerpt_range: context.excerpt.range,
2677 cursor_point: predict_edits_v3::Point {
2678 line: predict_edits_v3::Line(context.cursor_point.row),
2679 column: context.cursor_point.column,
2680 },
2681 referenced_declarations,
2682 included_files: vec![],
2683 signatures,
2684 excerpt_parent,
2685 events,
2686 can_collect_data,
2687 diagnostic_groups,
2688 diagnostic_groups_truncated,
2689 git_info,
2690 debug_info,
2691 prompt_max_bytes,
2692 prompt_format,
2693 trigger,
2694 }
2695}
2696
/// Ensures the signature of `declaration_id` — and, recursively, all of its
/// ancestors — is present in `signatures`, returning its index.
///
/// `declaration_to_signature_index` acts as a memo table so each declaration
/// is serialized at most once. Because ancestors are added before the current
/// declaration, a signature's `parent_index` always points at an earlier
/// element of `signatures`. Returns `None` when the declaration is missing
/// from the index.
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    // Already serialized — reuse the memoized index.
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    // Recurse into the ancestor chain first so parents precede children.
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}
2724
#[cfg(feature = "eval-support")]
/// Cache key for eval runs: the kind of request plus a 64-bit hash
/// (presumably of the request contents — confirm at call sites).
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2727
#[cfg(feature = "eval-support")]
/// The kind of LLM request an eval cache entry corresponds to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    /// Context-retrieval planning requests.
    Context,
    /// Retrieval search requests.
    Search,
    /// Edit-prediction requests.
    Prediction,
}
2735
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Renders the entry kind as its lowercase identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2746
#[cfg(feature = "eval-support")]
/// Storage backend for caching LLM responses during eval runs, keyed by
/// request kind and hash.
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` (the response to `input`) under `key`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2752
/// The user's persisted choice about edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not answered the data-collection prompt yet.
    NotAnswered,
    /// The user explicitly opted in.
    Enabled,
    /// The user explicitly opted out.
    Disabled,
}
2759
2760impl DataCollectionChoice {
2761 pub fn is_enabled(self) -> bool {
2762 match self {
2763 Self::Enabled => true,
2764 Self::NotAnswered | Self::Disabled => false,
2765 }
2766 }
2767
2768 pub fn is_answered(self) -> bool {
2769 match self {
2770 Self::Enabled | Self::Disabled => true,
2771 Self::NotAnswered => false,
2772 }
2773 }
2774
2775 #[must_use]
2776 pub fn toggle(&self) -> DataCollectionChoice {
2777 match self {
2778 Self::Enabled => Self::Disabled,
2779 Self::Disabled => Self::Enabled,
2780 Self::NotAnswered => Self::Enabled,
2781 }
2782 }
2783}
2784
2785impl From<bool> for DataCollectionChoice {
2786 fn from(value: bool) -> Self {
2787 match value {
2788 true => DataCollectionChoice::Enabled,
2789 false => DataCollectionChoice::Disabled,
2790 }
2791 }
2792}
2793
/// Marker type whose `Dismissable` impl tracks whether the edit-prediction
/// upsell has been dismissed.
struct ZedPredictUpsell;
2795
impl Dismissable for ZedPredictUpsell {
    const KEY: &'static str = "dismissed-edit-predict-upsell";

    /// Whether the upsell has effectively been dismissed, either via the
    /// dismissal key or via the legacy onboarding flow.
    fn dismissed() -> bool {
        // To make this backwards compatible with older versions of Zed, we
        // check if the user has seen the previous Edit Prediction Onboarding
        // before, by checking the data collection choice which was written to
        // the database once the user clicked on "Accept and Enable"
        if KEY_VALUE_STORE
            .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
            .log_err()
            .is_some_and(|s| s.is_some())
        {
            return true;
        }

        // Otherwise fall back to this type's own dismissal key.
        KEY_VALUE_STORE
            .read_kvp(Self::KEY)
            .log_err()
            .is_some_and(|s| s.is_some())
    }
}
2818
/// Whether the edit-prediction upsell modal should be shown, i.e. it has not
/// been dismissed (directly or via the legacy onboarding flow).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2822
/// Registers Zeta's workspace actions and feature-gates them in the command
/// palette.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rate-completions UI is only reachable behind a feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): despite the action name, this handler sets the edit
        // prediction provider to `None` rather than clearing an onboarding
        // flag — confirm this is the intended "reset" behavior.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2858
2859fn feature_gate_predict_edits_actions(cx: &mut App) {
2860 let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
2861 let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
2862 let zeta_all_action_types = [
2863 TypeId::of::<RateCompletions>(),
2864 TypeId::of::<ResetOnboarding>(),
2865 zed_actions::OpenZedPredictOnboarding.type_id(),
2866 TypeId::of::<ClearHistory>(),
2867 TypeId::of::<ThumbsUpActivePrediction>(),
2868 TypeId::of::<ThumbsDownActivePrediction>(),
2869 TypeId::of::<NextEdit>(),
2870 TypeId::of::<PreviousEdit>(),
2871 ];
2872
2873 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2874 filter.hide_action_types(&rate_completion_action_types);
2875 filter.hide_action_types(&reset_onboarding_action_types);
2876 filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
2877 });
2878
2879 cx.observe_global::<SettingsStore>(move |cx| {
2880 let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
2881 let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
2882
2883 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2884 if is_ai_disabled {
2885 filter.hide_action_types(&zeta_all_action_types);
2886 } else if has_feature_flag {
2887 filter.show_action_types(&rate_completion_action_types);
2888 } else {
2889 filter.hide_action_types(&rate_completion_action_types);
2890 }
2891 });
2892 })
2893 .detach();
2894
2895 cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
2896 if !DisableAiSettings::get_global(cx).disable_ai {
2897 if is_enabled {
2898 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2899 filter.show_action_types(&rate_completion_action_types);
2900 });
2901 } else {
2902 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2903 filter.hide_action_types(&rate_completion_action_types);
2904 });
2905 }
2906 }
2907 })
2908 .detach();
2909}
2910
2911#[cfg(test)]
2912mod tests {
2913 use std::{path::Path, sync::Arc, time::Duration};
2914
2915 use client::UserStore;
2916 use clock::FakeSystemClock;
2917 use cloud_llm_client::{
2918 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2919 };
2920 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2921 use futures::{
2922 AsyncReadExt, StreamExt,
2923 channel::{mpsc, oneshot},
2924 };
2925 use gpui::{
2926 Entity, TestAppContext,
2927 http_client::{FakeHttpClient, Response},
2928 prelude::*,
2929 };
2930 use indoc::indoc;
2931 use language::OffsetRangeExt as _;
2932 use open_ai::Usage;
2933 use pretty_assertions::{assert_eq, assert_matches};
2934 use project::{FakeFs, Project};
2935 use serde_json::json;
2936 use settings::SettingsStore;
2937 use util::path;
2938 use uuid::Uuid;
2939
2940 use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
2941
2942 #[gpui::test]
2943 async fn test_current_state(cx: &mut TestAppContext) {
2944 let (zeta, mut requests) = init_test(cx);
2945 let fs = FakeFs::new(cx.executor());
2946 fs.insert_tree(
2947 "/root",
2948 json!({
2949 "1.txt": "Hello!\nHow\nBye\n",
2950 "2.txt": "Hola!\nComo\nAdios\n"
2951 }),
2952 )
2953 .await;
2954 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2955
2956 zeta.update(cx, |zeta, cx| {
2957 zeta.register_project(&project, cx);
2958 });
2959
2960 let buffer1 = project
2961 .update(cx, |project, cx| {
2962 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
2963 project.open_buffer(path, cx)
2964 })
2965 .await
2966 .unwrap();
2967 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
2968 let position = snapshot1.anchor_before(language::Point::new(1, 3));
2969
2970 // Prediction for current file
2971
2972 zeta.update(cx, |zeta, cx| {
2973 zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
2974 });
2975 let (_request, respond_tx) = requests.predict.next().await.unwrap();
2976
2977 respond_tx
2978 .send(model_response(indoc! {r"
2979 --- a/root/1.txt
2980 +++ b/root/1.txt
2981 @@ ... @@
2982 Hello!
2983 -How
2984 +How are you?
2985 Bye
2986 "}))
2987 .unwrap();
2988
2989 cx.run_until_parked();
2990
2991 zeta.read_with(cx, |zeta, cx| {
2992 let prediction = zeta
2993 .current_prediction_for_buffer(&buffer1, &project, cx)
2994 .unwrap();
2995 assert_matches!(prediction, BufferEditPrediction::Local { .. });
2996 });
2997
2998 // Context refresh
2999 let refresh_task = zeta.update(cx, |zeta, cx| {
3000 zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
3001 });
3002 let (_request, respond_tx) = requests.predict.next().await.unwrap();
3003 respond_tx
3004 .send(open_ai::Response {
3005 id: Uuid::new_v4().to_string(),
3006 object: "response".into(),
3007 created: 0,
3008 model: "model".into(),
3009 choices: vec![open_ai::Choice {
3010 index: 0,
3011 message: open_ai::RequestMessage::Assistant {
3012 content: None,
3013 tool_calls: vec![open_ai::ToolCall {
3014 id: "search".into(),
3015 content: open_ai::ToolCallContent::Function {
3016 function: open_ai::FunctionContent {
3017 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
3018 .to_string(),
3019 arguments: serde_json::to_string(&SearchToolInput {
3020 queries: Box::new([SearchToolQuery {
3021 glob: "root/2.txt".to_string(),
3022 syntax_node: vec![],
3023 content: Some(".".into()),
3024 }]),
3025 })
3026 .unwrap(),
3027 },
3028 },
3029 }],
3030 },
3031 finish_reason: None,
3032 }],
3033 usage: Usage {
3034 prompt_tokens: 0,
3035 completion_tokens: 0,
3036 total_tokens: 0,
3037 },
3038 })
3039 .unwrap();
3040 refresh_task.await.unwrap();
3041
3042 zeta.update(cx, |zeta, _cx| {
3043 zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
3044 });
3045
3046 // Prediction for another file
3047 zeta.update(cx, |zeta, cx| {
3048 zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
3049 });
3050 let (_request, respond_tx) = requests.predict.next().await.unwrap();
3051 respond_tx
3052 .send(model_response(indoc! {r#"
3053 --- a/root/2.txt
3054 +++ b/root/2.txt
3055 Hola!
3056 -Como
3057 +Como estas?
3058 Adios
3059 "#}))
3060 .unwrap();
3061 cx.run_until_parked();
3062
3063 zeta.read_with(cx, |zeta, cx| {
3064 let prediction = zeta
3065 .current_prediction_for_buffer(&buffer1, &project, cx)
3066 .unwrap();
3067 assert_matches!(
3068 prediction,
3069 BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
3070 );
3071 });
3072
3073 let buffer2 = project
3074 .update(cx, |project, cx| {
3075 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
3076 project.open_buffer(path, cx)
3077 })
3078 .await
3079 .unwrap();
3080
3081 zeta.read_with(cx, |zeta, cx| {
3082 let prediction = zeta
3083 .current_prediction_for_buffer(&buffer2, &project, cx)
3084 .unwrap();
3085 assert_matches!(prediction, BufferEditPrediction::Local { .. });
3086 });
3087 }
3088
3089 #[gpui::test]
3090 async fn test_simple_request(cx: &mut TestAppContext) {
3091 let (zeta, mut requests) = init_test(cx);
3092 let fs = FakeFs::new(cx.executor());
3093 fs.insert_tree(
3094 "/root",
3095 json!({
3096 "foo.md": "Hello!\nHow\nBye\n"
3097 }),
3098 )
3099 .await;
3100 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3101
3102 let buffer = project
3103 .update(cx, |project, cx| {
3104 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3105 project.open_buffer(path, cx)
3106 })
3107 .await
3108 .unwrap();
3109 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3110 let position = snapshot.anchor_before(language::Point::new(1, 3));
3111
3112 let prediction_task = zeta.update(cx, |zeta, cx| {
3113 zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
3114 });
3115
3116 let (_, respond_tx) = requests.predict.next().await.unwrap();
3117
3118 // TODO Put back when we have a structured request again
3119 // assert_eq!(
3120 // request.excerpt_path.as_ref(),
3121 // Path::new(path!("root/foo.md"))
3122 // );
3123 // assert_eq!(
3124 // request.cursor_point,
3125 // Point {
3126 // line: Line(1),
3127 // column: 3
3128 // }
3129 // );
3130
3131 respond_tx
3132 .send(model_response(indoc! { r"
3133 --- a/root/foo.md
3134 +++ b/root/foo.md
3135 @@ ... @@
3136 Hello!
3137 -How
3138 +How are you?
3139 Bye
3140 "}))
3141 .unwrap();
3142
3143 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
3144
3145 assert_eq!(prediction.edits.len(), 1);
3146 assert_eq!(
3147 prediction.edits[0].0.to_point(&snapshot).start,
3148 language::Point::new(1, 3)
3149 );
3150 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
3151 }
3152
3153 #[gpui::test]
3154 async fn test_request_events(cx: &mut TestAppContext) {
3155 let (zeta, mut requests) = init_test(cx);
3156 let fs = FakeFs::new(cx.executor());
3157 fs.insert_tree(
3158 "/root",
3159 json!({
3160 "foo.md": "Hello!\n\nBye\n"
3161 }),
3162 )
3163 .await;
3164 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3165
3166 let buffer = project
3167 .update(cx, |project, cx| {
3168 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3169 project.open_buffer(path, cx)
3170 })
3171 .await
3172 .unwrap();
3173
3174 zeta.update(cx, |zeta, cx| {
3175 zeta.register_buffer(&buffer, &project, cx);
3176 });
3177
3178 buffer.update(cx, |buffer, cx| {
3179 buffer.edit(vec![(7..7, "How")], None, cx);
3180 });
3181
3182 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3183 let position = snapshot.anchor_before(language::Point::new(1, 3));
3184
3185 let prediction_task = zeta.update(cx, |zeta, cx| {
3186 zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
3187 });
3188
3189 let (request, respond_tx) = requests.predict.next().await.unwrap();
3190
3191 let prompt = prompt_from_request(&request);
3192 assert!(
3193 prompt.contains(indoc! {"
3194 --- a/root/foo.md
3195 +++ b/root/foo.md
3196 @@ -1,3 +1,3 @@
3197 Hello!
3198 -
3199 +How
3200 Bye
3201 "}),
3202 "{prompt}"
3203 );
3204
3205 respond_tx
3206 .send(model_response(indoc! {r#"
3207 --- a/root/foo.md
3208 +++ b/root/foo.md
3209 @@ ... @@
3210 Hello!
3211 -How
3212 +How are you?
3213 Bye
3214 "#}))
3215 .unwrap();
3216
3217 let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
3218
3219 assert_eq!(prediction.edits.len(), 1);
3220 assert_eq!(
3221 prediction.edits[0].0.to_point(&snapshot).start,
3222 language::Point::new(1, 3)
3223 );
3224 assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
3225 }
3226
3227 #[gpui::test]
3228 async fn test_empty_prediction(cx: &mut TestAppContext) {
3229 let (zeta, mut requests) = init_test(cx);
3230 let fs = FakeFs::new(cx.executor());
3231 fs.insert_tree(
3232 "/root",
3233 json!({
3234 "foo.md": "Hello!\nHow\nBye\n"
3235 }),
3236 )
3237 .await;
3238 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3239
3240 let buffer = project
3241 .update(cx, |project, cx| {
3242 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3243 project.open_buffer(path, cx)
3244 })
3245 .await
3246 .unwrap();
3247 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3248 let position = snapshot.anchor_before(language::Point::new(1, 3));
3249
3250 zeta.update(cx, |zeta, cx| {
3251 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3252 });
3253
3254 const NO_OP_DIFF: &str = indoc! { r"
3255 --- a/root/foo.md
3256 +++ b/root/foo.md
3257 @@ ... @@
3258 Hello!
3259 -How
3260 +How
3261 Bye
3262 "};
3263
3264 let (_, respond_tx) = requests.predict.next().await.unwrap();
3265 let response = model_response(NO_OP_DIFF);
3266 let id = response.id.clone();
3267 respond_tx.send(response).unwrap();
3268
3269 cx.run_until_parked();
3270
3271 zeta.read_with(cx, |zeta, cx| {
3272 assert!(
3273 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3274 .is_none()
3275 );
3276 });
3277
3278 // prediction is reported as rejected
3279 let (reject_request, _) = requests.reject.next().await.unwrap();
3280
3281 assert_eq!(
3282 &reject_request.rejections,
3283 &[EditPredictionRejection {
3284 request_id: id,
3285 reason: EditPredictionRejectReason::Empty,
3286 was_shown: false
3287 }]
3288 );
3289 }
3290
3291 #[gpui::test]
3292 async fn test_interpolated_empty(cx: &mut TestAppContext) {
3293 let (zeta, mut requests) = init_test(cx);
3294 let fs = FakeFs::new(cx.executor());
3295 fs.insert_tree(
3296 "/root",
3297 json!({
3298 "foo.md": "Hello!\nHow\nBye\n"
3299 }),
3300 )
3301 .await;
3302 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3303
3304 let buffer = project
3305 .update(cx, |project, cx| {
3306 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3307 project.open_buffer(path, cx)
3308 })
3309 .await
3310 .unwrap();
3311 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3312 let position = snapshot.anchor_before(language::Point::new(1, 3));
3313
3314 zeta.update(cx, |zeta, cx| {
3315 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3316 });
3317
3318 let (_, respond_tx) = requests.predict.next().await.unwrap();
3319
3320 buffer.update(cx, |buffer, cx| {
3321 buffer.set_text("Hello!\nHow are you?\nBye", cx);
3322 });
3323
3324 let response = model_response(SIMPLE_DIFF);
3325 let id = response.id.clone();
3326 respond_tx.send(response).unwrap();
3327
3328 cx.run_until_parked();
3329
3330 zeta.read_with(cx, |zeta, cx| {
3331 assert!(
3332 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3333 .is_none()
3334 );
3335 });
3336
3337 // prediction is reported as rejected
3338 let (reject_request, _) = requests.reject.next().await.unwrap();
3339
3340 assert_eq!(
3341 &reject_request.rejections,
3342 &[EditPredictionRejection {
3343 request_id: id,
3344 reason: EditPredictionRejectReason::InterpolatedEmpty,
3345 was_shown: false
3346 }]
3347 );
3348 }
3349
/// Unified diff rewriting line 2 of `foo.md` from "How" to "How are you?".
/// Used as the standard non-empty model response throughout these tests.
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
3359
3360 #[gpui::test]
3361 async fn test_replace_current(cx: &mut TestAppContext) {
3362 let (zeta, mut requests) = init_test(cx);
3363 let fs = FakeFs::new(cx.executor());
3364 fs.insert_tree(
3365 "/root",
3366 json!({
3367 "foo.md": "Hello!\nHow\nBye\n"
3368 }),
3369 )
3370 .await;
3371 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3372
3373 let buffer = project
3374 .update(cx, |project, cx| {
3375 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3376 project.open_buffer(path, cx)
3377 })
3378 .await
3379 .unwrap();
3380 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3381 let position = snapshot.anchor_before(language::Point::new(1, 3));
3382
3383 zeta.update(cx, |zeta, cx| {
3384 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3385 });
3386
3387 let (_, respond_tx) = requests.predict.next().await.unwrap();
3388 let first_response = model_response(SIMPLE_DIFF);
3389 let first_id = first_response.id.clone();
3390 respond_tx.send(first_response).unwrap();
3391
3392 cx.run_until_parked();
3393
3394 zeta.read_with(cx, |zeta, cx| {
3395 assert_eq!(
3396 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3397 .unwrap()
3398 .id
3399 .0,
3400 first_id
3401 );
3402 });
3403
3404 // a second request is triggered
3405 zeta.update(cx, |zeta, cx| {
3406 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3407 });
3408
3409 let (_, respond_tx) = requests.predict.next().await.unwrap();
3410 let second_response = model_response(SIMPLE_DIFF);
3411 let second_id = second_response.id.clone();
3412 respond_tx.send(second_response).unwrap();
3413
3414 cx.run_until_parked();
3415
3416 zeta.read_with(cx, |zeta, cx| {
3417 // second replaces first
3418 assert_eq!(
3419 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3420 .unwrap()
3421 .id
3422 .0,
3423 second_id
3424 );
3425 });
3426
3427 // first is reported as replaced
3428 let (reject_request, _) = requests.reject.next().await.unwrap();
3429
3430 assert_eq!(
3431 &reject_request.rejections,
3432 &[EditPredictionRejection {
3433 request_id: first_id,
3434 reason: EditPredictionRejectReason::Replaced,
3435 was_shown: false
3436 }]
3437 );
3438 }
3439
3440 #[gpui::test]
3441 async fn test_current_preferred(cx: &mut TestAppContext) {
3442 let (zeta, mut requests) = init_test(cx);
3443 let fs = FakeFs::new(cx.executor());
3444 fs.insert_tree(
3445 "/root",
3446 json!({
3447 "foo.md": "Hello!\nHow\nBye\n"
3448 }),
3449 )
3450 .await;
3451 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3452
3453 let buffer = project
3454 .update(cx, |project, cx| {
3455 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3456 project.open_buffer(path, cx)
3457 })
3458 .await
3459 .unwrap();
3460 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3461 let position = snapshot.anchor_before(language::Point::new(1, 3));
3462
3463 zeta.update(cx, |zeta, cx| {
3464 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3465 });
3466
3467 let (_, respond_tx) = requests.predict.next().await.unwrap();
3468 let first_response = model_response(SIMPLE_DIFF);
3469 let first_id = first_response.id.clone();
3470 respond_tx.send(first_response).unwrap();
3471
3472 cx.run_until_parked();
3473
3474 zeta.read_with(cx, |zeta, cx| {
3475 assert_eq!(
3476 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3477 .unwrap()
3478 .id
3479 .0,
3480 first_id
3481 );
3482 });
3483
3484 // a second request is triggered
3485 zeta.update(cx, |zeta, cx| {
3486 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3487 });
3488
3489 let (_, respond_tx) = requests.predict.next().await.unwrap();
3490 // worse than current prediction
3491 let second_response = model_response(indoc! { r"
3492 --- a/root/foo.md
3493 +++ b/root/foo.md
3494 @@ ... @@
3495 Hello!
3496 -How
3497 +How are
3498 Bye
3499 "});
3500 let second_id = second_response.id.clone();
3501 respond_tx.send(second_response).unwrap();
3502
3503 cx.run_until_parked();
3504
3505 zeta.read_with(cx, |zeta, cx| {
3506 // first is preferred over second
3507 assert_eq!(
3508 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3509 .unwrap()
3510 .id
3511 .0,
3512 first_id
3513 );
3514 });
3515
3516 // second is reported as rejected
3517 let (reject_request, _) = requests.reject.next().await.unwrap();
3518
3519 assert_eq!(
3520 &reject_request.rejections,
3521 &[EditPredictionRejection {
3522 request_id: second_id,
3523 reason: EditPredictionRejectReason::CurrentPreferred,
3524 was_shown: false
3525 }]
3526 );
3527 }
3528
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Start two refresh tasks back to back, holding on to both responders so
    // the test controls which request completes first.
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so both requests are actually sent
    cx.run_until_parked();

    // The newer (second) request responds first...
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // ...and becomes the current prediction.
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now the older (first) request responds, late.
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected with `Canceled`
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
3617
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    // Start two refresh tasks, keeping both responders pending.
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    zeta.update(cx, |zeta, cx| {
        // start a third request while the first two are still pending
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so the 2nd (index 1) is marked cancelled
        assert_eq!(
            zeta.get_or_init_zeta_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle so the third request is sent
    cx.run_until_parked();

    let (_, respond_third) = requests.predict.next().await.unwrap();

    // The first (oldest, not cancelled) request completes...
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // ...and becomes the current prediction.
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request responds late.
    let cancelled_response = model_response(SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // third completes and replaces first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as cancelled, and first as replaced (batched into one request)
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
3749
#[gpui::test]
async fn test_rejections_flushing(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);

    // Queue two rejections; they should be sent together after the debounce.
    zeta.update(cx, |zeta, _cx| {
        zeta.reject_prediction(
            EditPredictionId("test-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
        zeta.reject_prediction(
            EditPredictionId("test-2".into()),
            EditPredictionRejectReason::Canceled,
            true,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    // batched into a single request, preserving insertion order
    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(
        reject_request.rejections[0],
        EditPredictionRejection {
            request_id: "test-1".to_string(),
            reason: EditPredictionRejectReason::Discarded,
            was_shown: false
        }
    );
    assert_eq!(
        reject_request.rejections[1],
        EditPredictionRejection {
            request_id: "test-2".to_string(),
            reason: EditPredictionRejectReason::Canceled,
            was_shown: true
        }
    );

    // Reaching batch size limit sends without debounce
    zeta.update(cx, |zeta, _cx| {
        for i in 0..70 {
            zeta.reject_prediction(
                EditPredictionId(format!("batch-{}", i).into()),
                EditPredictionRejectReason::Discarded,
                false,
            );
        }
    });

    // First MAX/2 items are sent immediately
    cx.run_until_parked();
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 50);
    assert_eq!(reject_request.rejections[0].request_id, "batch-0");
    assert_eq!(reject_request.rejections[49].request_id, "batch-49");

    // Remaining items are debounced with the next batch
    // NOTE(review): presumably 15s >= REJECT_REQUEST_DEBOUNCE — consider using
    // the constant here instead of a magic duration; confirm before changing.
    cx.executor().advance_clock(Duration::from_secs(15));
    cx.run_until_parked();

    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 20);
    assert_eq!(reject_request.rejections[0].request_id, "batch-50");
    assert_eq!(reject_request.rejections[19].request_id, "batch-69");

    // Request failure: items from a failed send should be retried later.
    zeta.update(cx, |zeta, _cx| {
        zeta.reject_prediction(
            EditPredictionId("retry-1".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
    assert_eq!(reject_request.rejections.len(), 1);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    // Simulate failure by dropping the responder (the fake server's oneshot)
    drop(_respond_tx);

    // Add another rejection
    zeta.update(cx, |zeta, _cx| {
        zeta.reject_prediction(
            EditPredictionId("retry-2".into()),
            EditPredictionRejectReason::Discarded,
            false,
        );
    });

    cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
    cx.run_until_parked();

    // Retry should include both the failed item and the new one
    let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
    respond_tx.send(()).unwrap();

    assert_eq!(reject_request.rejections.len(), 2);
    assert_eq!(reject_request.rejections[0].request_id, "retry-1");
    assert_eq!(reject_request.rejections[1].request_id, "retry-2");
}
3861
3862 // Skipped until we start including diagnostics in prompt
3863 // #[gpui::test]
3864 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3865 // let (zeta, mut req_rx) = init_test(cx);
3866 // let fs = FakeFs::new(cx.executor());
3867 // fs.insert_tree(
3868 // "/root",
3869 // json!({
3870 // "foo.md": "Hello!\nBye"
3871 // }),
3872 // )
3873 // .await;
3874 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3875
3876 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3877 // let diagnostic = lsp::Diagnostic {
3878 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3879 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3880 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3881 // ..Default::default()
3882 // };
3883
3884 // project.update(cx, |project, cx| {
3885 // project.lsp_store().update(cx, |lsp_store, cx| {
3886 // // Create some diagnostics
3887 // lsp_store
3888 // .update_diagnostics(
3889 // LanguageServerId(0),
3890 // lsp::PublishDiagnosticsParams {
3891 // uri: path_to_buffer_uri.clone(),
3892 // diagnostics: vec![diagnostic],
3893 // version: None,
3894 // },
3895 // None,
3896 // language::DiagnosticSourceKind::Pushed,
3897 // &[],
3898 // cx,
3899 // )
3900 // .unwrap();
3901 // });
3902 // });
3903
3904 // let buffer = project
3905 // .update(cx, |project, cx| {
3906 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3907 // project.open_buffer(path, cx)
3908 // })
3909 // .await
3910 // .unwrap();
3911
3912 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3913 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3914
3915 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3916 // zeta.request_prediction(&project, &buffer, position, cx)
3917 // });
3918
3919 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3920
3921 // assert_eq!(request.diagnostic_groups.len(), 1);
3922 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3923 // .unwrap();
3924 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3925 // assert_eq!(
3926 // value,
3927 // json!({
3928 // "entries": [{
3929 // "range": {
3930 // "start": 8,
3931 // "end": 10
3932 // },
3933 // "diagnostic": {
3934 // "source": null,
3935 // "code": null,
3936 // "code_description": null,
3937 // "severity": 1,
3938 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3939 // "markdown": null,
3940 // "group_id": 0,
3941 // "is_primary": true,
3942 // "is_disk_based": false,
3943 // "is_unnecessary": false,
3944 // "source_kind": "Pushed",
3945 // "data": null,
3946 // "underline": true
3947 // }
3948 // }],
3949 // "primary_ix": 0
3950 // })
3951 // );
3952 // }
3953
3954 fn model_response(text: &str) -> open_ai::Response {
3955 open_ai::Response {
3956 id: Uuid::new_v4().to_string(),
3957 object: "response".into(),
3958 created: 0,
3959 model: "model".into(),
3960 choices: vec![open_ai::Choice {
3961 index: 0,
3962 message: open_ai::RequestMessage::Assistant {
3963 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3964 tool_calls: vec![],
3965 },
3966 finish_reason: None,
3967 }],
3968 usage: Usage {
3969 prompt_tokens: 0,
3970 completion_tokens: 0,
3971 total_tokens: 0,
3972 },
3973 }
3974 }
3975
3976 fn prompt_from_request(request: &open_ai::Request) -> &str {
3977 assert_eq!(request.messages.len(), 1);
3978 let open_ai::RequestMessage::User {
3979 content: open_ai::MessageContent::Plain(content),
3980 ..
3981 } = &request.messages[0]
3982 else {
3983 panic!(
3984 "Request does not have single user message of type Plain. {:#?}",
3985 request
3986 );
3987 };
3988 content
3989 }
3990
/// Receivers for the HTTP requests intercepted by the fake client in `init_test`.
struct RequestChannels {
    // Prediction requests (`/predict_edits/raw`), paired with a oneshot sender
    // the test uses to supply the model response.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Rejection reports (`/predict_edits/reject`), paired with a oneshot sender
    // to acknowledge the request; dropping it simulates a failed request.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
3995
/// Builds a `Zeta` entity backed by a fake HTTP client, returning channels that
/// surface intercepted `/predict_edits/raw` and `/predict_edits/reject` requests
/// so tests can inspect them and supply responses via oneshot senders.
fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
    cx.update(move |cx| {
        // Settings must be installed as a global before anything reads them.
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh always succeeds with a fixed token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction requests are forwarded to the test over the
                        // `predict` channel; the test answers via the oneshot.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection reports are forwarded over `reject`;
                        // dropping the oneshot sender makes `res_rx.await`
                        // fail, simulating a failed HTTP request.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = Zeta::global(&client, &user_store, cx);

        (
            zeta,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
4062}