1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
9};
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
12use collections::{HashMap, HashSet};
13use command_palette_hooks::CommandPaletteFilter;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{
16 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
17 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
18 SyntaxIndex, SyntaxIndexState,
19};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
21use futures::channel::mpsc::UnboundedReceiver;
22use futures::channel::{mpsc, oneshot};
23use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
24use gpui::{
25 App, AsyncApp, BackgroundExecutor, Entity, EntityId, Global, SharedString, Subscription, Task,
26 WeakEntity, actions,
27 http_client::{self, AsyncBody, Method},
28 prelude::*,
29};
30use language::{
31 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
32};
33use language::{BufferSnapshot, OffsetRangeExt};
34use language_model::{LlmApiToken, RefreshLlmTokenListener};
35use open_ai::FunctionDefinition;
36use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
37use release_channel::AppVersion;
38use semver::Version;
39use serde::de::DeserializeOwned;
40use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
41use std::any::{Any as _, TypeId};
42use std::collections::{VecDeque, hash_map};
43use telemetry_events::EditPredictionRating;
44use workspace::Workspace;
45
46use std::ops::Range;
47use std::path::Path;
48use std::rc::Rc;
49use std::str::FromStr as _;
50use std::sync::{Arc, LazyLock};
51use std::time::{Duration, Instant};
52use std::{env, mem};
53use thiserror::Error;
54use util::rel_path::RelPathBuf;
55use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod assemble_excerpts;
59mod license_detection;
60mod onboarding_modal;
61mod prediction;
62mod provider;
63mod rate_prediction_modal;
64pub mod retrieval_search;
65mod sweep_ai;
66pub mod udiff;
67mod xml_edits;
68pub mod zeta1;
69
70#[cfg(test)]
71mod zeta_tests;
72
73use crate::assemble_excerpts::assemble_excerpts;
74use crate::license_detection::LicenseDetectionWatcher;
75use crate::onboarding_modal::ZedPredictModal;
76pub use crate::prediction::EditPrediction;
77pub use crate::prediction::EditPredictionId;
78pub use crate::prediction::EditPredictionInputs;
79use crate::prediction::EditPredictionResult;
80use crate::rate_prediction_modal::{
81 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
82 ThumbsUpActivePrediction,
83};
84use crate::sweep_ai::SweepAi;
85use crate::zeta1::request_prediction_with_zeta1;
86pub use provider::ZetaEditPredictionProvider;
87
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
99
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into a single event
/// (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// Presumably the key-value-store key for the user's data collection choice;
// the load/save sites are not in this chunk — confirm in `load_data_collection_choice`.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
/// How long rejection reports are batched before being flushed to the server.
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
105
106pub struct SweepFeatureFlag;
107
108impl FeatureFlag for SweepFeatureFlag {
109 const NAME: &str = "sweep-ai";
110}
/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: agentic retrieval with the default excerpt sizing.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for the agentic (search-driven) context mode.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for the syntax-index-driven context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options used when constructing a `Zeta` (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
142
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value; switches the model
/// ids and prediction URL below to a local Ollama server.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for context retrieval; overridable via `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit predictions; overridable via `ZED_ZETA2_MODEL`
/// (the special value "zeta2-exp" maps to a fine-tuned model id).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional override for the prediction endpoint: `ZED_PREDICT_EDITS_URL` wins,
/// then the local Ollama endpoint when `USE_OLLAMA`, otherwise `None`.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
170
/// Feature flag gating the zeta2 prediction pipeline; enabled by default for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
180
/// Newtype wrapper so the singleton `Zeta` entity can be stored as a gpui global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
185
/// App-wide edit-prediction coordinator, stored as a global (see `Zeta::global`).
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for LLM requests; refreshed through `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the `Project` entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): appears to flag that the client must update before predicting;
    // only the initialization site is visible in this chunk — confirm where it is set.
    update_required: bool,
    /// When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    /// Which backend serves predictions (zeta1 / zeta2 / Sweep).
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    /// Sender feeding the background rejection-batching task spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
    /// Most recently shown predictions, newest first, capped at 50.
    shown_predictions: VecDeque<EditPrediction>,
    /// Ids of predictions the user has rated; pruned with `shown_predictions`.
    rated_predictions: HashSet<EditPredictionId>,
}
204
/// Backend used to produce edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
212
/// Tunables for prediction requests; defaults in `DEFAULT_OPTIONS`.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when the syntax context mode is used.
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}
222
/// How context for a prediction is gathered: model-driven search (agentic) or
/// the local syntax index.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}
228
/// Options specific to the agentic context mode.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
233
234impl ContextMode {
235 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
236 match self {
237 ContextMode::Agentic(options) => &options.excerpt,
238 ContextMode::Syntax(options) => &options.excerpt,
239 }
240 }
241}
242
/// Events streamed to the debug channel returned by `Zeta::debug_info`.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
251
/// Debug payload emitted when context retrieval begins.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}
258
/// Debug payload for context-retrieval progress events (queries executed / finished).
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
264
/// Debug payload emitted when an edit prediction request is issued.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally-rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
274
/// Debug payload emitted when the model generates context search queries.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
281
/// Alias for the cloud protocol's per-request debug info type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
283
/// Per-project prediction state owned by `Zeta`.
struct ZetaProject {
    /// Present only when `ContextMode::Syntax` is configured (see `get_or_init_zeta_project`).
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events; oldest are popped to respect `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Most recent change, kept open so nearby edits can still coalesce into it.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Monotonic id source for `PendingPrediction`s.
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Entity and time of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Ids of pending predictions that were cancelled before resolving.
    cancelled_predictions: HashSet<usize>,
    /// Retrieved context ranges per buffer, when available.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree, used to mark events as open-source or not.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
302
impl ZetaProject {
    /// All events for this project: the finalized queue plus the still-open
    /// `last_event`, finalized on the fly (which may yield nothing if it was a no-op).
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks a pending prediction as cancelled, then — once its task resolves —
    /// reports the resulting prediction (if any) as rejected with `Canceled`.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, _cx| {
                this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
            })
            .ok();
        })
        .detach()
    }
}
336
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request (buffer edit vs. diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed (affects rejection reporting).
    pub was_shown: bool,
}
343
impl CurrentEditPrediction {
    /// Whether `self` (a new prediction) should replace `old_prediction` as the
    /// one offered to the user.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer interpolates against the current buffer
        // contents, it is stale — keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // The old prediction failing to interpolate means it is stale — replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
382
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// A user edit in the buffer with this entity id.
    Buffer(EntityId),
}
388
389impl PredictionRequestedBy {
390 pub fn buffer_id(&self) -> Option<EntityId> {
391 match self {
392 PredictionRequestedBy::DiagnosticsUpdate => None,
393 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
394 }
395 }
396}
397
/// An in-flight prediction request.
#[derive(Debug)]
struct PendingPrediction {
    /// Matches entries in `ZetaProject::cancelled_predictions`.
    id: usize,
    /// Resolves to the produced prediction's id, or `None` if none was produced.
    task: Task<Option<EditPredictionId>>,
}
403
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; the UI can offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
410
// Test-only convenience: deref either variant straight to the underlying prediction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
422
/// Per-buffer bookkeeping: the last snapshot we diffed against, plus the
/// edit/release subscriptions that keep it registered.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
427
/// The most recent buffer change, kept unfinalized so further nearby edits can
/// be coalesced into it (see `report_changes_for_buffer`).
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// End of the last edit; its row distance to the next edit decides coalescing.
    end_edit_anchor: Option<Anchor>,
}
433
impl LastEvent {
    /// Converts this pending change into a protocol event, or `None` when the
    /// change is a no-op (same path and empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Open-source only if both the old and new file locations have a watcher
        // that has identified an open-source license.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
470
471fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
472 if let Some(file) = snapshot.file() {
473 file.full_path(cx).into()
474 } else {
475 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
476 }
477}
478
479impl Zeta {
480 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
481 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
482 }
483
484 pub fn global(
485 client: &Arc<Client>,
486 user_store: &Entity<UserStore>,
487 cx: &mut App,
488 ) -> Entity<Self> {
489 cx.try_global::<ZetaGlobal>()
490 .map(|global| global.0.clone())
491 .unwrap_or_else(|| {
492 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
493 cx.set_global(ZetaGlobal(zeta.clone()));
494 zeta
495 })
496 }
497
    /// Creates a `Zeta`, spawning the long-lived background task that batches
    /// and flushes prediction rejections, and subscribing to LLM token refreshes.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let llm_token = LlmApiToken::default();

        // Rejections sent on `reject_tx` are consumed by `handle_rejected_predictions`
        // for the lifetime of the app.
        let (reject_tx, reject_rx) = mpsc::unbounded();
        cx.background_spawn({
            let client = client.clone();
            let llm_token = llm_token.clone();
            let app_version = AppVersion::global(cx);
            let background_executor = cx.background_executor().clone();
            async move {
                Self::handle_rejected_predictions(
                    reject_rx,
                    client,
                    llm_token,
                    app_version,
                    background_executor,
                )
                .await
            }
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token,
            // Re-fetch the LLM token whenever the listener fires.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
553
    /// Selects which backend serves edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
557
    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }
561
    /// Installs the eval cache (only available with the `eval-support` feature).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
566
567 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
568 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
569 self.debug_tx = Some(debug_watch_tx);
570 debug_watch_rx
571 }
572
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
576
    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
580
581 pub fn clear_history(&mut self) {
582 for zeta_project in self.projects.values_mut() {
583 zeta_project.events.clear();
584 }
585 }
586
587 pub fn context_for_project(
588 &self,
589 project: &Entity<Project>,
590 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
591 self.projects
592 .get(&project.entity_id())
593 .and_then(|project| {
594 Some(
595 project
596 .context
597 .as_ref()?
598 .iter()
599 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
600 )
601 })
602 .into_iter()
603 .flatten()
604 }
605
606 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
607 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
608 self.user_store.read(cx).edit_prediction_usage()
609 } else {
610 None
611 }
612 }
613
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
617
    /// Starts tracking `buffer` within `project`, subscribing to its edits and
    /// release so changes feed the project's event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
627
    /// Returns the project's state, creating it (and subscribing to project
    /// events) on first access.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // The syntax index is only needed for the syntax context mode.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
660
    /// Reacts to project events: maintains the most-recently-active path list and
    /// refreshes predictions on diagnostics updates (zeta2 flag only).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, deduplicating if already present.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
693
    /// Registers `buffer` with the project state (idempotent), wiring up:
    /// a license watcher for the buffer's worktree, an edit subscription that
    /// reports changes, and release subscriptions that clean up the maps.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Ensure a license watcher exists for this buffer's worktree, removing it
        // when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Report every edit so it can be added to the event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Deregister when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
753
    /// Records a buffer edit in the project's event history, coalescing it into
    /// the previous event when it continues the same edit run nearby.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // No version change means nothing to record.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End position of the last edit in this change, used for proximity grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only if this change picks up exactly where the pending
            // event left off (same buffer, consecutive versions)...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the edits are within CHANGE_GROUPING_LINE_SPAN rows of each other.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Make room before finalizing the pending event into the queue.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
816
    /// How the project's current prediction relates to `buffer`: `Local` when it
    /// edits this buffer, `Jump` when the user should be offered a jump to its
    /// target buffer, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only offer a jump from the buffer that requested the prediction;
            // diagnostics-triggered predictions offer a jump from any buffer.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
848
    /// Accepts the project's current prediction: cancels pending requests and
    /// reports the acceptance to the server. No-op for the Sweep backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Any still-pending requests are now moot; cancel (and reject) them.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the default endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
901
    /// Long-running loop that batches incoming rejections and flushes them to
    /// the server, debounced by `REJECT_REQUEST_DEBOUNCE`. Failed flushes keep
    /// their items in `batched` so they are retried with the next flush.
    async fn handle_rejected_predictions(
        rx: UnboundedReceiver<EditPredictionRejection>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        background_executor: BackgroundExecutor,
    ) {
        let mut rx = std::pin::pin!(rx.peekable());
        let mut batched = Vec::new();

        while let Some(rejection) = rx.next().await {
            batched.push(rejection);

            // Keep accumulating until the batch is half the server cap, another
            // item is immediately available, or the debounce timer fires.
            if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
                select_biased! {
                    next = rx.as_mut().peek().fuse() => {
                        if next.is_some() {
                            continue;
                        }
                    }
                    () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
                }
            }

            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])
                .unwrap();

            // Flush only the newest `flush_count` items, respecting the per-request cap.
            let flush_count = batched
                .len()
                // in case items have accumulated after failure
                .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
            let start = batched.len() - flush_count;

            let body = RejectEditPredictionsBodyRef {
                rejections: &batched[start..],
            };

            let result = Self::send_api_request::<()>(
                |builder| {
                    let req = builder
                        .uri(url.as_ref())
                        .body(serde_json::to_string(&body)?.into());
                    anyhow::Ok(req?)
                },
                client.clone(),
                llm_token.clone(),
                app_version.clone(),
            )
            .await;

            // Only drop the flushed items on success; on failure they remain
            // queued for the next attempt.
            if result.log_err().is_some() {
                batched.drain(start..);
            }
        }
    }
959
960 fn reject_current_prediction(
961 &mut self,
962 reason: EditPredictionRejectReason,
963 project: &Entity<Project>,
964 ) {
965 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
966 project_state.pending_predictions.clear();
967 if let Some(prediction) = project_state.current_prediction.take() {
968 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
969 }
970 };
971 }
972
973 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
974 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
975 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
976 if !current_prediction.was_shown {
977 current_prediction.was_shown = true;
978 self.shown_predictions
979 .push_front(current_prediction.prediction.clone());
980 if self.shown_predictions.len() > 50 {
981 let completion = self.shown_predictions.pop_back().unwrap();
982 self.rated_predictions.remove(&completion.id);
983 }
984 }
985 }
986 }
987 }
988
    /// Queues a rejection report for the background batching task.
    /// No-op for the Sweep backend.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
    ) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        self.reject_predictions_tx
            .unbounded_send(EditPredictionRejection {
                request_id: prediction_id.to_string(),
                reason,
                was_shown,
            })
            .log_err();
    }
1008
1009 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1010 self.projects
1011 .get(&project.entity_id())
1012 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1013 }
1014
    /// Queues a prediction refresh triggered by activity in `buffer` at `position`,
    /// throttled per-buffer via `queue_prediction_refresh`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(&project, &buffer, position, cx)
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            // Tag the result with the originating buffer so the UI knows where
            // the request came from.
            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1044
    /// Queues a prediction refresh triggered by a diagnostics update: finds the
    /// next diagnostic location in the active buffer and requests a prediction
    /// there. Buffer-triggered predictions take priority over this path.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Open the buffer for the project's active entry, if there is one.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-triggered prediction may have arrived while we
                        // were waiting; if so, reject this one in its favor.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1119
    /// Minimum delay between consecutive prediction refreshes triggered by the
    /// same throttle entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Test builds skip throttling so tests don't have to wait out the timer.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1124
1125 fn queue_prediction_refresh(
1126 &mut self,
1127 project: Entity<Project>,
1128 throttle_entity: EntityId,
1129 cx: &mut Context<Self>,
1130 do_refresh: impl FnOnce(
1131 WeakEntity<Self>,
1132 &mut AsyncApp,
1133 )
1134 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1135 + 'static,
1136 ) {
1137 let zeta_project = self.get_or_init_zeta_project(&project, cx);
1138 let pending_prediction_id = zeta_project.next_pending_prediction_id;
1139 zeta_project.next_pending_prediction_id += 1;
1140 let last_request = zeta_project.last_prediction_refresh;
1141
1142 let task = cx.spawn(async move |this, cx| {
1143 if let Some((last_entity, last_timestamp)) = last_request
1144 && throttle_entity == last_entity
1145 && let Some(timeout) =
1146 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1147 {
1148 cx.background_executor().timer(timeout).await;
1149 }
1150
1151 // If this task was cancelled before the throttle timeout expired,
1152 // do not perform a request.
1153 let mut is_cancelled = true;
1154 this.update(cx, |this, cx| {
1155 let project_state = this.get_or_init_zeta_project(&project, cx);
1156 if !project_state
1157 .cancelled_predictions
1158 .remove(&pending_prediction_id)
1159 {
1160 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1161 is_cancelled = false;
1162 }
1163 })
1164 .ok();
1165 if is_cancelled {
1166 return None;
1167 }
1168
1169 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1170 let new_prediction_id = new_prediction_result
1171 .as_ref()
1172 .map(|(prediction, _)| prediction.id.clone());
1173
1174 // When a prediction completes, remove it from the pending list, and cancel
1175 // any pending predictions that were enqueued before it.
1176 this.update(cx, |this, cx| {
1177 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1178
1179 let is_cancelled = zeta_project
1180 .cancelled_predictions
1181 .remove(&pending_prediction_id);
1182
1183 let new_current_prediction = if !is_cancelled
1184 && let Some((prediction_result, requested_by)) = new_prediction_result
1185 {
1186 match prediction_result.prediction {
1187 Ok(prediction) => {
1188 let new_prediction = CurrentEditPrediction {
1189 requested_by,
1190 prediction,
1191 was_shown: false,
1192 };
1193
1194 if let Some(current_prediction) =
1195 zeta_project.current_prediction.as_ref()
1196 {
1197 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1198 {
1199 this.reject_current_prediction(
1200 EditPredictionRejectReason::Replaced,
1201 &project,
1202 );
1203
1204 Some(new_prediction)
1205 } else {
1206 this.reject_prediction(
1207 new_prediction.prediction.id,
1208 EditPredictionRejectReason::CurrentPreferred,
1209 false,
1210 );
1211 None
1212 }
1213 } else {
1214 Some(new_prediction)
1215 }
1216 }
1217 Err(reject_reason) => {
1218 this.reject_prediction(prediction_result.id, reject_reason, false);
1219 None
1220 }
1221 }
1222 } else {
1223 None
1224 };
1225
1226 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1227
1228 if let Some(new_prediction) = new_current_prediction {
1229 zeta_project.current_prediction = Some(new_prediction);
1230 }
1231
1232 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
1233 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1234 if pending_prediction.id == pending_prediction_id {
1235 pending_predictions.remove(ix);
1236 for pending_prediction in pending_predictions.drain(0..ix) {
1237 zeta_project.cancel_pending_prediction(pending_prediction, cx)
1238 }
1239 break;
1240 }
1241 }
1242 this.get_or_init_zeta_project(&project, cx)
1243 .pending_predictions = pending_predictions;
1244 cx.notify();
1245 })
1246 .ok();
1247
1248 new_prediction_id
1249 });
1250
1251 if zeta_project.pending_predictions.len() <= 1 {
1252 zeta_project.pending_predictions.push(PendingPrediction {
1253 id: pending_prediction_id,
1254 task,
1255 });
1256 } else if zeta_project.pending_predictions.len() == 2 {
1257 let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
1258 zeta_project.pending_predictions.push(PendingPrediction {
1259 id: pending_prediction_id,
1260 task,
1261 });
1262 zeta_project.cancel_pending_prediction(pending_prediction, cx);
1263 }
1264 }
1265
1266 pub fn request_prediction(
1267 &mut self,
1268 project: &Entity<Project>,
1269 active_buffer: &Entity<Buffer>,
1270 position: language::Anchor,
1271 cx: &mut Context<Self>,
1272 ) -> Task<Result<Option<EditPredictionResult>>> {
1273 self.request_prediction_internal(
1274 project.clone(),
1275 active_buffer.clone(),
1276 position,
1277 cx.has_flag::<Zeta2FeatureFlag>(),
1278 cx,
1279 )
1280 }
1281
    /// Core prediction request shared by all entry points.
    ///
    /// Dispatches to the configured model backend (Zeta1, Zeta2, or Sweep).
    /// When the backend returns no prediction and `allow_jump` is set, retries
    /// once at the nearest diagnostic outside the already-searched range
    /// (with `allow_jump` cleared so it can't recurse further).
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above/below the cursor already covered by the primary request;
        // the diagnostic jump only considers diagnostics outside this window.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there were recent edit events; otherwise a
                // missing prediction just means there is nothing to predict.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1369
    /// Finds a diagnostic location to "jump" a prediction request to.
    ///
    /// First looks in the active buffer for the diagnostic closest to the
    /// cursor that falls outside `active_buffer_diagnostic_search_range`
    /// (i.e. one that wasn't already covered by the previous request). If none
    /// exists, falls back to opening the project buffer with diagnostics whose
    /// path shares the most leading directories with the active buffer's path,
    /// and returns that buffer's top-priority diagnostic.
    ///
    /// Returns `None` when no suitable diagnostic is found anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    // Already covered by the previous request's window.
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by row distance to the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // `min` picks the most severe entry — assumes
                            // LSP-style ordering where lower severity values
                            // are more severe (Error = 1). TODO confirm.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1448
    /// Requests an edit prediction from the Zeta2 model.
    ///
    /// Gathers context according to `options.context` (agentic retrieval over
    /// previously-collected context files, or syntax-index-based context),
    /// builds a prompt, sends it to the LLM endpoint on a background task,
    /// parses the model's diff/XML output into buffer edits, and resolves
    /// which included buffer the edits apply to.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Shared syntax-index state, if this project has one.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let file = active_buffer.read(cx).file();
        // Parent directory of the active file, used for syntax-context gathering.
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        // In eval builds, wait for context buffers to finish parsing so
        // results are deterministic.
        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // (buffer entity, snapshot, full path, anchor ranges) per context file;
        // files without a path are skipped.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Stable ordering for prompt reproducibility: by path, then range count.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            // No usable excerpt around the cursor; nothing to predict.
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // existing ranges (moving it last), or append the
                        // active buffer as a new entry.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line-based excerpts for the wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            // In agentic mode the excerpt content travels via
                            // `included_files`; these fields stay empty.
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                // Inputs are retained alongside the prediction for telemetry/debugging.
                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // When a debug listener is attached, notify it of the request
                // and hand it a channel for the eventual response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug escape hatch: allow inspecting prompts without hitting the server.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    // Empty response: still report the request id so it can be
                    // recorded as an empty prediction.
                    return Ok((Some((request_id, None)), usage));
                };

                // The model sometimes echoes the cursor marker back; remove it.
                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path from the model output to a context buffer.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse model output into edits, using the parser matching the
                // prompt format we asked for.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            // Model's sentinel for "no changes".
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        Some((
                            inputs,
                            edited_buffer,
                            edited_buffer_snapshot.clone(),
                            edits,
                            received_response_at,
                        )),
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, prediction)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // A request id without a prediction body means the model
                // returned nothing usable.
                let Some((
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = prediction
                else {
                    return Ok(Some(EditPredictionResult {
                        id,
                        prediction: Err(EditPredictionRejectReason::Empty),
                    }));
                };

                // TODO telemetry: duration, etc
                Ok(Some(
                    EditPredictionResult::new(
                        id,
                        &edited_buffer,
                        &edited_buffer_snapshot,
                        edits.into(),
                        buffer_snapshotted_at,
                        received_response_at,
                        inputs,
                        cx,
                    )
                    .await,
                ))
            }
        })
    }
1864
1865 async fn send_raw_llm_request(
1866 request: open_ai::Request,
1867 client: Arc<Client>,
1868 llm_token: LlmApiToken,
1869 app_version: Version,
1870 #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1871 #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1872 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1873 let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1874 http_client::Url::parse(&predict_edits_url)?
1875 } else {
1876 client
1877 .http_client()
1878 .build_zed_llm_url("/predict_edits/raw", &[])?
1879 };
1880
1881 #[cfg(feature = "eval-support")]
1882 let cache_key = if let Some(cache) = eval_cache {
1883 use collections::FxHasher;
1884 use std::hash::{Hash, Hasher};
1885
1886 let mut hasher = FxHasher::default();
1887 url.hash(&mut hasher);
1888 let request_str = serde_json::to_string_pretty(&request)?;
1889 request_str.hash(&mut hasher);
1890 let hash = hasher.finish();
1891
1892 let key = (eval_cache_kind, hash);
1893 if let Some(response_str) = cache.read(key) {
1894 return Ok((serde_json::from_str(&response_str)?, None));
1895 }
1896
1897 Some((cache, request_str, key))
1898 } else {
1899 None
1900 };
1901
1902 let (response, usage) = Self::send_api_request(
1903 |builder| {
1904 let req = builder
1905 .uri(url.as_ref())
1906 .body(serde_json::to_string(&request)?.into());
1907 Ok(req?)
1908 },
1909 client,
1910 llm_token,
1911 app_version,
1912 )
1913 .await?;
1914
1915 #[cfg(feature = "eval-support")]
1916 if let Some((cache, request, key)) = cache_key {
1917 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1918 }
1919
1920 Ok((response, usage))
1921 }
1922
    /// Unpacks an API response pair of `(data, usage)`.
    ///
    /// On success, records any reported edit-prediction usage on the user
    /// store and returns the data. On failure, if the error is
    /// `ZedUpdateRequiredError`, sets `update_required` and shows a
    /// notification linking to the releases page; the error is returned
    /// either way.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1966
    /// POSTs a JSON request built by `build` to the LLM service and
    /// deserializes the JSON response body into `Res`.
    ///
    /// Attaches the bearer LLM token and the Zed version header. If the server
    /// reports an expired token, refreshes it and retries exactly once. Fails
    /// with `ZedUpdateRequiredError` when the server's minimum-version header
    /// exceeds this client's version, and with a generic error (including the
    /// response body) for any other non-success status.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the initial attempt plus one token-refresh retry.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server's minimum supported client version, when present.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired server-side: refresh and retry once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2030
    /// Gap between edits after which the user counts as idle; the first edit
    /// after an idle period triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied while the user is actively editing before refreshing context.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2033
    /// Refreshes the related excerpts when the user begins editing after an
    /// idle period, and after they pause editing.
    ///
    /// Only applies when the Zeta2 model with agentic context mode is active.
    /// Edits after an idle period refresh immediately; otherwise the refresh
    /// is debounced by `CONTEXT_RETRIEVAL_DEBOUNCE_DURATION`.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
            return;
        }

        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // Idle when no edit has been seen within the idle duration (or ever).
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the debounce task drops the previous one, restarting the
        // debounce window on every edit.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the task so the refresh runs to completion even
                    // after this debounce task finishes.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2085
2086 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2087 // and avoid spawning more than one concurrent task.
2088 pub fn refresh_context(
2089 &mut self,
2090 project: Entity<Project>,
2091 buffer: Entity<language::Buffer>,
2092 cursor_position: language::Anchor,
2093 cx: &mut Context<Self>,
2094 ) -> Task<Result<()>> {
2095 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2096 return Task::ready(anyhow::Ok(()));
2097 };
2098
2099 let ContextMode::Agentic(options) = &self.options().context else {
2100 return Task::ready(anyhow::Ok(()));
2101 };
2102
2103 let snapshot = buffer.read(cx).snapshot();
2104 let cursor_point = cursor_position.to_point(&snapshot);
2105 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
2106 cursor_point,
2107 &snapshot,
2108 &options.excerpt,
2109 None,
2110 ) else {
2111 return Task::ready(Ok(()));
2112 };
2113
2114 let app_version = AppVersion::global(cx);
2115 let client = self.client.clone();
2116 let llm_token = self.llm_token.clone();
2117 let debug_tx = self.debug_tx.clone();
2118 let current_file_path: Arc<Path> = snapshot
2119 .file()
2120 .map(|f| f.full_path(cx).into())
2121 .unwrap_or_else(|| Path::new("untitled").into());
2122
2123 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
2124 predict_edits_v3::PlanContextRetrievalRequest {
2125 excerpt: cursor_excerpt.text(&snapshot).body,
2126 excerpt_path: current_file_path,
2127 excerpt_line_range: cursor_excerpt.line_range,
2128 cursor_file_max_row: Line(snapshot.max_point().row),
2129 events: zeta_project.events(cx),
2130 },
2131 ) {
2132 Ok(prompt) => prompt,
2133 Err(err) => {
2134 return Task::ready(Err(err));
2135 }
2136 };
2137
2138 if let Some(debug_tx) = &debug_tx {
2139 debug_tx
2140 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
2141 ZetaContextRetrievalStartedDebugInfo {
2142 project: project.clone(),
2143 timestamp: Instant::now(),
2144 search_prompt: prompt.clone(),
2145 },
2146 ))
2147 .ok();
2148 }
2149
2150 pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
2151 let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
2152 language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
2153 );
2154
2155 let description = schema
2156 .get("description")
2157 .and_then(|description| description.as_str())
2158 .unwrap()
2159 .to_string();
2160
2161 (schema.into(), description)
2162 });
2163
2164 let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
2165
2166 let request = open_ai::Request {
2167 model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
2168 messages: vec![open_ai::RequestMessage::User {
2169 content: open_ai::MessageContent::Plain(prompt),
2170 }],
2171 stream: false,
2172 max_completion_tokens: None,
2173 stop: Default::default(),
2174 temperature: 0.7,
2175 tool_choice: None,
2176 parallel_tool_calls: None,
2177 tools: vec![open_ai::ToolDefinition::Function {
2178 function: FunctionDefinition {
2179 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
2180 description: Some(tool_description),
2181 parameters: Some(tool_schema),
2182 },
2183 }],
2184 prompt_cache_key: None,
2185 reasoning_effort: None,
2186 };
2187
2188 #[cfg(feature = "eval-support")]
2189 let eval_cache = self.eval_cache.clone();
2190
2191 cx.spawn(async move |this, cx| {
2192 log::trace!("Sending search planning request");
2193 let response = Self::send_raw_llm_request(
2194 request,
2195 client,
2196 llm_token,
2197 app_version,
2198 #[cfg(feature = "eval-support")]
2199 eval_cache.clone(),
2200 #[cfg(feature = "eval-support")]
2201 EvalCacheEntryKind::Context,
2202 )
2203 .await;
2204 let mut response = Self::handle_api_response(&this, response, cx)?;
2205 log::trace!("Got search planning response");
2206
2207 let choice = response
2208 .choices
2209 .pop()
2210 .context("No choices in retrieval response")?;
2211 let open_ai::RequestMessage::Assistant {
2212 content: _,
2213 tool_calls,
2214 } = choice.message
2215 else {
2216 anyhow::bail!("Retrieval response didn't include an assistant message");
2217 };
2218
2219 let mut queries: Vec<SearchToolQuery> = Vec::new();
2220 for tool_call in tool_calls {
2221 let open_ai::ToolCallContent::Function { function } = tool_call.content;
2222 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
2223 log::warn!(
2224 "Context retrieval response tried to call an unknown tool: {}",
2225 function.name
2226 );
2227
2228 continue;
2229 }
2230
2231 let input: SearchToolInput = serde_json::from_str(&function.arguments)
2232 .with_context(|| format!("invalid search json {}", &function.arguments))?;
2233 queries.extend(input.queries);
2234 }
2235
2236 if let Some(debug_tx) = &debug_tx {
2237 debug_tx
2238 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
2239 ZetaSearchQueryDebugInfo {
2240 project: project.clone(),
2241 timestamp: Instant::now(),
2242 search_queries: queries.clone(),
2243 },
2244 ))
2245 .ok();
2246 }
2247
2248 log::trace!("Running retrieval search: {queries:#?}");
2249
2250 let related_excerpts_result = retrieval_search::run_retrieval_searches(
2251 queries,
2252 project.clone(),
2253 #[cfg(feature = "eval-support")]
2254 eval_cache,
2255 cx,
2256 )
2257 .await;
2258
2259 log::trace!("Search queries executed");
2260
2261 if let Some(debug_tx) = &debug_tx {
2262 debug_tx
2263 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
2264 ZetaContextRetrievalDebugInfo {
2265 project: project.clone(),
2266 timestamp: Instant::now(),
2267 },
2268 ))
2269 .ok();
2270 }
2271
2272 this.update(cx, |this, _cx| {
2273 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
2274 return Ok(());
2275 };
2276 zeta_project.refresh_context_task.take();
2277 if let Some(debug_tx) = &this.debug_tx {
2278 debug_tx
2279 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
2280 ZetaContextRetrievalDebugInfo {
2281 project,
2282 timestamp: Instant::now(),
2283 },
2284 ))
2285 .ok();
2286 }
2287 match related_excerpts_result {
2288 Ok(excerpts) => {
2289 zeta_project.context = Some(excerpts);
2290 Ok(())
2291 }
2292 Err(error) => Err(error),
2293 }
2294 })?
2295 })
2296 }
2297
2298 pub fn set_context(
2299 &mut self,
2300 project: Entity<Project>,
2301 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2302 ) {
2303 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2304 zeta_project.context = Some(context);
2305 }
2306 }
2307
    /// Collects diagnostic groups from all language servers, orders them by
    /// proximity to the cursor, and serializes them until their combined size
    /// would exceed `max_diagnostics_bytes`.
    ///
    /// Returns the serialized groups plus a flag indicating whether any
    /// groups were dropped due to the byte budget.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            // Distance from the cursor to the group's primary entry range.
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                // Cursor lies inside the range: rank by distance to the
                // nearer edge. NOTE(review): an overlapping range arguably
                // has distance 0 — confirm nearest-edge ranking is intended.
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                // The group that overflows the budget is dropped, along with
                // all remaining (more distant) groups.
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }
2353
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds (but does not send) a `PredictEditsRequest` for the zeta CLI.
    ///
    /// Only `ContextMode::Syntax` is supported; agentic context mode panics.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax-index state handle so it can be locked on the
        // background thread below.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, if it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the whole context-gathering pass so it
            // observes a consistent snapshot of the index.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2431
2432 pub fn wait_for_initial_indexing(
2433 &mut self,
2434 project: &Entity<Project>,
2435 cx: &mut Context<Self>,
2436 ) -> Task<Result<()>> {
2437 let zeta_project = self.get_or_init_zeta_project(project, cx);
2438 if let Some(syntax_index) = &zeta_project.syntax_index {
2439 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2440 } else {
2441 Task::ready(Ok(()))
2442 }
2443 }
2444
2445 fn is_file_open_source(
2446 &self,
2447 project: &Entity<Project>,
2448 file: &Arc<dyn File>,
2449 cx: &App,
2450 ) -> bool {
2451 if !file.is_local() || file.is_private() {
2452 return false;
2453 }
2454 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2455 return false;
2456 };
2457 zeta_project
2458 .license_detection_watchers
2459 .get(&file.worktree_id(cx))
2460 .as_ref()
2461 .is_some_and(|watcher| watcher.is_project_open_source())
2462 }
2463
2464 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2465 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2466 }
2467
2468 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2469 if !self.data_collection_choice.is_enabled() {
2470 return false;
2471 }
2472 events.iter().all(|event| {
2473 matches!(
2474 event.as_ref(),
2475 Event::BufferChange {
2476 in_open_source_repo: true,
2477 ..
2478 }
2479 )
2480 })
2481 }
2482
2483 fn load_data_collection_choice() -> DataCollectionChoice {
2484 let choice = KEY_VALUE_STORE
2485 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2486 .log_err()
2487 .flatten();
2488
2489 match choice.as_deref() {
2490 Some("true") => DataCollectionChoice::Enabled,
2491 Some("false") => DataCollectionChoice::Disabled,
2492 Some(_) => {
2493 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2494 DataCollectionChoice::NotAnswered
2495 }
2496 None => DataCollectionChoice::NotAnswered,
2497 }
2498 }
2499
    /// Iterator over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2503
    /// Number of predictions shown to the user so far.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2507
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2511
    /// Records the user's rating (plus free-form feedback) for a prediction
    /// and reports it via telemetry.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Push the event immediately instead of waiting for the next
        // scheduled telemetry flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2530}
2531
2532pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2533 let choice = res.choices.pop()?;
2534 let output_text = match choice.message {
2535 open_ai::RequestMessage::Assistant {
2536 content: Some(open_ai::MessageContent::Plain(content)),
2537 ..
2538 } => content,
2539 open_ai::RequestMessage::Assistant {
2540 content: Some(open_ai::MessageContent::Multipart(mut content)),
2541 ..
2542 } => {
2543 if content.is_empty() {
2544 log::error!("No output from Baseten completion response");
2545 return None;
2546 }
2547
2548 match content.remove(0) {
2549 open_ai::MessagePart::Text { text } => text,
2550 open_ai::MessagePart::Image { .. } => {
2551 log::error!("Expected text, got an image");
2552 return None;
2553 }
2554 }
2555 }
2556 _ => {
2557 log::error!("Invalid response message: {:?}", choice.message);
2558 return None;
2559 }
2560 };
2561 Some(output_text)
2562}
2563
/// Error indicating the client's Zed version is below the minimum the edit
/// prediction service supports.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2571
2572fn make_syntax_context_cloud_request(
2573 excerpt_path: Arc<Path>,
2574 context: EditPredictionContext,
2575 events: Vec<Arc<predict_edits_v3::Event>>,
2576 can_collect_data: bool,
2577 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2578 diagnostic_groups_truncated: bool,
2579 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2580 debug_info: bool,
2581 worktrees: &Vec<worktree::Snapshot>,
2582 index_state: Option<&SyntaxIndexState>,
2583 prompt_max_bytes: Option<usize>,
2584 prompt_format: PromptFormat,
2585) -> predict_edits_v3::PredictEditsRequest {
2586 let mut signatures = Vec::new();
2587 let mut declaration_to_signature_index = HashMap::default();
2588 let mut referenced_declarations = Vec::new();
2589
2590 for snippet in context.declarations {
2591 let project_entry_id = snippet.declaration.project_entry_id();
2592 let Some(path) = worktrees.iter().find_map(|worktree| {
2593 worktree.entry_for_id(project_entry_id).map(|entry| {
2594 let mut full_path = RelPathBuf::new();
2595 full_path.push(worktree.root_name());
2596 full_path.push(&entry.path);
2597 full_path
2598 })
2599 }) else {
2600 continue;
2601 };
2602
2603 let parent_index = index_state.and_then(|index_state| {
2604 snippet.declaration.parent().and_then(|parent| {
2605 add_signature(
2606 parent,
2607 &mut declaration_to_signature_index,
2608 &mut signatures,
2609 index_state,
2610 )
2611 })
2612 });
2613
2614 let (text, text_is_truncated) = snippet.declaration.item_text();
2615 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2616 path: path.as_std_path().into(),
2617 text: text.into(),
2618 range: snippet.declaration.item_line_range(),
2619 text_is_truncated,
2620 signature_range: snippet.declaration.signature_range_in_item_text(),
2621 parent_index,
2622 signature_score: snippet.score(DeclarationStyle::Signature),
2623 declaration_score: snippet.score(DeclarationStyle::Declaration),
2624 score_components: snippet.components,
2625 });
2626 }
2627
2628 let excerpt_parent = index_state.and_then(|index_state| {
2629 context
2630 .excerpt
2631 .parent_declarations
2632 .last()
2633 .and_then(|(parent, _)| {
2634 add_signature(
2635 *parent,
2636 &mut declaration_to_signature_index,
2637 &mut signatures,
2638 index_state,
2639 )
2640 })
2641 });
2642
2643 predict_edits_v3::PredictEditsRequest {
2644 excerpt_path,
2645 excerpt: context.excerpt_text.body,
2646 excerpt_line_range: context.excerpt.line_range,
2647 excerpt_range: context.excerpt.range,
2648 cursor_point: predict_edits_v3::Point {
2649 line: predict_edits_v3::Line(context.cursor_point.row),
2650 column: context.cursor_point.column,
2651 },
2652 referenced_declarations,
2653 included_files: vec![],
2654 signatures,
2655 excerpt_parent,
2656 events,
2657 can_collect_data,
2658 diagnostic_groups,
2659 diagnostic_groups_truncated,
2660 git_info,
2661 debug_info,
2662 prompt_max_bytes,
2663 prompt_format,
2664 }
2665}
2666
2667fn add_signature(
2668 declaration_id: DeclarationId,
2669 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2670 signatures: &mut Vec<Signature>,
2671 index: &SyntaxIndexState,
2672) -> Option<usize> {
2673 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2674 return Some(*signature_index);
2675 }
2676 let Some(parent_declaration) = index.declaration(declaration_id) else {
2677 log::error!("bug: missing parent declaration");
2678 return None;
2679 };
2680 let parent_index = parent_declaration.parent().and_then(|parent| {
2681 add_signature(parent, declaration_to_signature_index, signatures, index)
2682 });
2683 let (text, text_is_truncated) = parent_declaration.signature_text();
2684 let signature_index = signatures.len();
2685 signatures.push(Signature {
2686 text: text.into(),
2687 text_is_truncated,
2688 parent_index,
2689 range: parent_declaration.signature_line_range(),
2690 });
2691 declaration_to_signature_index.insert(declaration_id, signature_index);
2692 Some(signature_index)
2693}
2694
/// Key identifying a cached eval entry: the round-trip kind plus a `u64`
/// discriminator (presumably a hash of the request — confirm at call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2697
/// Kind of LLM round-trip that can be cached during evals.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2705
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the stable lowercase name of the entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        })
    }
}
2716
/// Storage for cached LLM responses, used to make evals reproducible.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Records `value` as the response for `key`; `input` is the associated
    /// request text (presumably kept for inspection — confirm in impls).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2722
/// The user's choice about edit-prediction data collection.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Whether collection is currently on. An unanswered choice counts as off.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the flipped choice; toggling an unanswered choice enables it.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value { Self::Enabled } else { Self::Disabled }
    }
}
2763
2764struct ZedPredictUpsell;
2765
2766impl Dismissable for ZedPredictUpsell {
2767 const KEY: &'static str = "dismissed-edit-predict-upsell";
2768
2769 fn dismissed() -> bool {
2770 // To make this backwards compatible with older versions of Zed, we
2771 // check if the user has seen the previous Edit Prediction Onboarding
2772 // before, by checking the data collection choice which was written to
2773 // the database once the user clicked on "Accept and Enable"
2774 if KEY_VALUE_STORE
2775 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2776 .log_err()
2777 .is_some_and(|s| s.is_some())
2778 {
2779 return true;
2780 }
2781
2782 KEY_VALUE_STORE
2783 .read_kvp(Self::KEY)
2784 .log_err()
2785 .is_some_and(|s| s.is_some())
2786 }
2787}
2788
/// Whether the edit-prediction upsell modal should be shown, i.e. the user
/// has not dismissed it (in this or an older version of Zed).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2792
/// Registers edit-prediction workspace actions and their command-palette
/// feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured provider so the
        // onboarding flow can be triggered again.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2828
/// Hides or shows edit-prediction actions in the command palette based on
/// the "disable AI" setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Superset of zeta actions, hidden wholesale when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Hidden by default; the observers below selectively re-show them.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                // NOTE(review): after AI is re-enabled, only the
                // rate-completions actions are re-shown here; the other
                // actions in `zeta_all_action_types` stay hidden. Confirm
                // this asymmetry is intentional.
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2880
2881#[cfg(test)]
2882mod tests {
2883 use std::{path::Path, sync::Arc, time::Duration};
2884
2885 use client::UserStore;
2886 use clock::FakeSystemClock;
2887 use cloud_llm_client::{
2888 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2889 };
2890 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2891 use futures::{
2892 AsyncReadExt, StreamExt,
2893 channel::{mpsc, oneshot},
2894 };
2895 use gpui::{
2896 Entity, TestAppContext,
2897 http_client::{FakeHttpClient, Response},
2898 prelude::*,
2899 };
2900 use indoc::indoc;
2901 use language::OffsetRangeExt as _;
2902 use open_ai::Usage;
2903 use pretty_assertions::{assert_eq, assert_matches};
2904 use project::{FakeFs, Project};
2905 use serde_json::json;
2906 use settings::SettingsStore;
2907 use util::path;
2908 use uuid::Uuid;
2909
2910 use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
2911
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        // Exercises `current_prediction_for_buffer` across three phases: a
        // prediction local to the cursor's file, a context-refresh round
        // trip, and a prediction targeting a different file.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        // The planning request is answered with a tool call carrying the
        // search queries to execute.
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // While buffer1 is current, a prediction for 2.txt surfaces as a
        // `Jump` pointing at that file.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt's buffer is queried, the same prediction is `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3058
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        // Single prediction round trip: the mocked model returns a unified
        // diff and the resulting edit's anchor and text are verified.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // "How" -> "How are you?" is a single insertion at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3122
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        // Edits made to a registered buffer are tracked as events, and the
        // prediction prompt should contain their unified diff.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register before editing so the edit below is recorded as an event.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3196
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        // A diff that changes nothing should produce no current prediction
        // and be auto-reported as rejected with reason `Empty`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Replaces "How" with identical text, yielding an empty edit set.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How
            Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3260
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        // If the user types the predicted text before the response arrives,
        // interpolation leaves no edits; the prediction is dropped and
        // reported as rejected with reason `InterpolatedEmpty`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Apply the predicted change locally before the response lands.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3319
    /// Model-response diff used by several tests: rewrites line 2 of
    /// `root/foo.md` from `How` to `How are you?`.
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are you?
        Bye
    "};
3329
3330 #[gpui::test]
3331 async fn test_replace_current(cx: &mut TestAppContext) {
3332 let (zeta, mut requests) = init_test(cx);
3333 let fs = FakeFs::new(cx.executor());
3334 fs.insert_tree(
3335 "/root",
3336 json!({
3337 "foo.md": "Hello!\nHow\nBye\n"
3338 }),
3339 )
3340 .await;
3341 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3342
3343 let buffer = project
3344 .update(cx, |project, cx| {
3345 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3346 project.open_buffer(path, cx)
3347 })
3348 .await
3349 .unwrap();
3350 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3351 let position = snapshot.anchor_before(language::Point::new(1, 3));
3352
3353 zeta.update(cx, |zeta, cx| {
3354 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3355 });
3356
3357 let (_, respond_tx) = requests.predict.next().await.unwrap();
3358 let first_response = model_response(SIMPLE_DIFF);
3359 let first_id = first_response.id.clone();
3360 respond_tx.send(first_response).unwrap();
3361
3362 cx.run_until_parked();
3363
3364 zeta.read_with(cx, |zeta, cx| {
3365 assert_eq!(
3366 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3367 .unwrap()
3368 .id
3369 .0,
3370 first_id
3371 );
3372 });
3373
3374 // a second request is triggered
3375 zeta.update(cx, |zeta, cx| {
3376 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3377 });
3378
3379 let (_, respond_tx) = requests.predict.next().await.unwrap();
3380 let second_response = model_response(SIMPLE_DIFF);
3381 let second_id = second_response.id.clone();
3382 respond_tx.send(second_response).unwrap();
3383
3384 cx.run_until_parked();
3385
3386 zeta.read_with(cx, |zeta, cx| {
3387 // second replaces first
3388 assert_eq!(
3389 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3390 .unwrap()
3391 .id
3392 .0,
3393 second_id
3394 );
3395 });
3396
3397 // first is reported as replaced
3398 let (reject_request, _) = requests.reject.next().await.unwrap();
3399
3400 assert_eq!(
3401 &reject_request.rejections,
3402 &[EditPredictionRejection {
3403 request_id: first_id,
3404 reason: EditPredictionRejectReason::Replaced,
3405 was_shown: false
3406 }]
3407 );
3408 }
3409
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        // When a newer prediction scores worse than the one currently shown,
        // the current prediction is kept and the newcomer is rejected with
        // `CurrentPreferred`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // First refresh: respond and confirm the prediction becomes current.
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // A shorter, worse edit than the current prediction's ("How are" vs
        // "How are you?").
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3498
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        // Two refreshes are issued back to back; the second response arrives
        // first and becomes current. The first request is thereby cancelled,
        // so its late response must not displace the newer prediction.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle, so both requests are actually sent
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // now deliver the late response for the cancelled first request
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected with the `Canceled` reason
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3587
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        // Starting a third refresh while two requests are still pending
        // cancels the second one. The second's late response must be ignored,
        // while the first and third still complete and compete normally.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        // start two refresh tasks
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so the 2nd (index 1) is marked cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // deliver the late response for the cancelled second request
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // both rejections arrive in a single batch: the cancelled second
        // request, then the first prediction which the third replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3719
3720 #[gpui::test]
3721 async fn test_rejections_flushing(cx: &mut TestAppContext) {
3722 let (zeta, mut requests) = init_test(cx);
3723
3724 zeta.update(cx, |zeta, _cx| {
3725 zeta.reject_prediction(
3726 EditPredictionId("test-1".into()),
3727 EditPredictionRejectReason::Discarded,
3728 false,
3729 );
3730 zeta.reject_prediction(
3731 EditPredictionId("test-2".into()),
3732 EditPredictionRejectReason::Canceled,
3733 true,
3734 );
3735 });
3736
3737 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
3738 cx.run_until_parked();
3739
3740 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3741 respond_tx.send(()).unwrap();
3742
3743 // batched
3744 assert_eq!(reject_request.rejections.len(), 2);
3745 assert_eq!(
3746 reject_request.rejections[0],
3747 EditPredictionRejection {
3748 request_id: "test-1".to_string(),
3749 reason: EditPredictionRejectReason::Discarded,
3750 was_shown: false
3751 }
3752 );
3753 assert_eq!(
3754 reject_request.rejections[1],
3755 EditPredictionRejection {
3756 request_id: "test-2".to_string(),
3757 reason: EditPredictionRejectReason::Canceled,
3758 was_shown: true
3759 }
3760 );
3761
3762 // Reaching batch size limit sends without debounce
3763 zeta.update(cx, |zeta, _cx| {
3764 for i in 0..70 {
3765 zeta.reject_prediction(
3766 EditPredictionId(format!("batch-{}", i).into()),
3767 EditPredictionRejectReason::Discarded,
3768 false,
3769 );
3770 }
3771 });
3772
3773 // First MAX/2 items are sent immediately
3774 cx.run_until_parked();
3775 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3776 respond_tx.send(()).unwrap();
3777
3778 assert_eq!(reject_request.rejections.len(), 50);
3779 assert_eq!(reject_request.rejections[0].request_id, "batch-0");
3780 assert_eq!(reject_request.rejections[49].request_id, "batch-49");
3781
3782 // Remaining items are debounced with the next batch
3783 cx.executor().advance_clock(Duration::from_secs(15));
3784 cx.run_until_parked();
3785
3786 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3787 respond_tx.send(()).unwrap();
3788
3789 assert_eq!(reject_request.rejections.len(), 20);
3790 assert_eq!(reject_request.rejections[0].request_id, "batch-50");
3791 assert_eq!(reject_request.rejections[19].request_id, "batch-69");
3792
3793 // Request failure
3794 zeta.update(cx, |zeta, _cx| {
3795 zeta.reject_prediction(
3796 EditPredictionId("retry-1".into()),
3797 EditPredictionRejectReason::Discarded,
3798 false,
3799 );
3800 });
3801
3802 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
3803 cx.run_until_parked();
3804
3805 let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
3806 assert_eq!(reject_request.rejections.len(), 1);
3807 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
3808 // Simulate failure
3809 drop(_respond_tx);
3810
3811 // Add another rejection
3812 zeta.update(cx, |zeta, _cx| {
3813 zeta.reject_prediction(
3814 EditPredictionId("retry-2".into()),
3815 EditPredictionRejectReason::Discarded,
3816 false,
3817 );
3818 });
3819
3820 cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
3821 cx.run_until_parked();
3822
3823 // Retry should include both the failed item and the new one
3824 let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
3825 respond_tx.send(()).unwrap();
3826
3827 assert_eq!(reject_request.rejections.len(), 2);
3828 assert_eq!(reject_request.rejections[0].request_id, "retry-1");
3829 assert_eq!(reject_request.rejections[1].request_id, "retry-2");
3830 }
3831
3832 // Skipped until we start including diagnostics in prompt
3833 // #[gpui::test]
3834 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3835 // let (zeta, mut req_rx) = init_test(cx);
3836 // let fs = FakeFs::new(cx.executor());
3837 // fs.insert_tree(
3838 // "/root",
3839 // json!({
3840 // "foo.md": "Hello!\nBye"
3841 // }),
3842 // )
3843 // .await;
3844 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3845
3846 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3847 // let diagnostic = lsp::Diagnostic {
3848 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3849 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3850 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3851 // ..Default::default()
3852 // };
3853
3854 // project.update(cx, |project, cx| {
3855 // project.lsp_store().update(cx, |lsp_store, cx| {
3856 // // Create some diagnostics
3857 // lsp_store
3858 // .update_diagnostics(
3859 // LanguageServerId(0),
3860 // lsp::PublishDiagnosticsParams {
3861 // uri: path_to_buffer_uri.clone(),
3862 // diagnostics: vec![diagnostic],
3863 // version: None,
3864 // },
3865 // None,
3866 // language::DiagnosticSourceKind::Pushed,
3867 // &[],
3868 // cx,
3869 // )
3870 // .unwrap();
3871 // });
3872 // });
3873
3874 // let buffer = project
3875 // .update(cx, |project, cx| {
3876 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3877 // project.open_buffer(path, cx)
3878 // })
3879 // .await
3880 // .unwrap();
3881
3882 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3883 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3884
3885 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3886 // zeta.request_prediction(&project, &buffer, position, cx)
3887 // });
3888
3889 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3890
3891 // assert_eq!(request.diagnostic_groups.len(), 1);
3892 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3893 // .unwrap();
3894 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3895 // assert_eq!(
3896 // value,
3897 // json!({
3898 // "entries": [{
3899 // "range": {
3900 // "start": 8,
3901 // "end": 10
3902 // },
3903 // "diagnostic": {
3904 // "source": null,
3905 // "code": null,
3906 // "code_description": null,
3907 // "severity": 1,
3908 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3909 // "markdown": null,
3910 // "group_id": 0,
3911 // "is_primary": true,
3912 // "is_disk_based": false,
3913 // "is_unnecessary": false,
3914 // "source_kind": "Pushed",
3915 // "data": null,
3916 // "underline": true
3917 // }
3918 // }],
3919 // "primary_ix": 0
3920 // })
3921 // );
3922 // }
3923
3924 fn model_response(text: &str) -> open_ai::Response {
3925 open_ai::Response {
3926 id: Uuid::new_v4().to_string(),
3927 object: "response".into(),
3928 created: 0,
3929 model: "model".into(),
3930 choices: vec![open_ai::Choice {
3931 index: 0,
3932 message: open_ai::RequestMessage::Assistant {
3933 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3934 tool_calls: vec![],
3935 },
3936 finish_reason: None,
3937 }],
3938 usage: Usage {
3939 prompt_tokens: 0,
3940 completion_tokens: 0,
3941 total_tokens: 0,
3942 },
3943 }
3944 }
3945
3946 fn prompt_from_request(request: &open_ai::Request) -> &str {
3947 assert_eq!(request.messages.len(), 1);
3948 let open_ai::RequestMessage::User {
3949 content: open_ai::MessageContent::Plain(content),
3950 ..
3951 } = &request.messages[0]
3952 else {
3953 panic!(
3954 "Request does not have single user message of type Plain. {:#?}",
3955 request
3956 );
3957 };
3958 content
3959 }
3960
    /// Receiving ends of the fake HTTP endpoints installed by `init_test`.
    ///
    /// Each item pairs a decoded request body with a oneshot sender that the
    /// test uses to deliver the response — or drops, to simulate a failure.
    struct RequestChannels {
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3965
    /// Sets up a `Zeta` instance backed by a fake HTTP client for tests.
    ///
    /// The fake server answers `/client/llm_tokens` with a static token and
    /// forwards `/predict_edits/raw` and `/predict_edits/reject` request
    /// bodies to the returned `RequestChannels`, letting tests inspect each
    /// request and control when (and whether) its response is delivered.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: reply with a static test token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the decoded request to the test, then
                                // wait for it to supply (or drop) the response.
                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
4032}