1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
9};
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
12use collections::{HashMap, HashSet};
13use command_palette_hooks::CommandPaletteFilter;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{
16 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
17 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
18 SyntaxIndex, SyntaxIndexState,
19};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
21use futures::channel::{mpsc, oneshot};
22use futures::{AsyncReadExt as _, StreamExt as _};
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::{
29 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
30};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, RefreshLlmTokenListener};
33use open_ai::FunctionDefinition;
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
39use std::any::{Any as _, TypeId};
40use std::collections::{VecDeque, hash_map};
41use telemetry_events::EditPredictionRating;
42use workspace::Workspace;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant};
50use std::{env, mem};
51use thiserror::Error;
52use util::rel_path::RelPathBuf;
53use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
55
56pub mod assemble_excerpts;
57mod license_detection;
58mod onboarding_modal;
59mod prediction;
60mod provider;
61mod rate_prediction_modal;
62pub mod retrieval_search;
63mod sweep_ai;
64pub mod udiff;
65mod xml_edits;
66pub mod zeta1;
67
68#[cfg(test)]
69mod zeta_tests;
70
71use crate::assemble_excerpts::assemble_excerpts;
72use crate::license_detection::LicenseDetectionWatcher;
73use crate::onboarding_modal::ZedPredictModal;
74pub use crate::prediction::EditPrediction;
75pub use crate::prediction::EditPredictionId;
76pub use crate::prediction::EditPredictionInputs;
77use crate::prediction::EditPredictionResult;
78use crate::rate_prediction_modal::{
79 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
80 ThumbsUpActivePrediction,
81};
82use crate::sweep_ai::SweepAi;
83use crate::zeta1::request_prediction_with_zeta1;
84pub use provider::ZetaEditPredictionProvider;
85
// Workspace actions contributed by the edit-prediction crate.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
97
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits whose end rows are within this many lines of the previous edit are
/// coalesced into a single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key persisting the user's data-collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";

/// Feature flag gating the Sweep AI edit-prediction backend.
pub struct SweepFeatureFlag;

impl FeatureFlag for SweepFeatureFlag {
    const NAME: &str = "sweep-ai";
}
/// Default sizing for the excerpt taken around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering mode: agentic retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for syntax-index-based context gathering.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options for the Zeta prediction engine.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

// Set ZED_ZETA2_OLLAMA to any non-empty value to route requests to a local
// Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
// Model used for the context-retrieval (search) step; overridable via env var.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
// Model used for the edit-prediction step; overridable via env var.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
// Optional override for the prediction endpoint; defaults to the local Ollama
// endpoint when ZED_ZETA2_OLLAMA is set, otherwise the hosted URL is used.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

/// Feature flag gating the Zeta2 prediction model; enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

/// App-global holder for the singleton [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
182
/// Central state for edit predictions: tracks per-project state, talks to the
/// prediction backends, and records shown/rejected/rated predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Keyed by the `Project` entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server reports the client version is too old.
    update_required: bool,
    // When present, debug events are mirrored to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections batched up and flushed to the server (see `reject_prediction`).
    rejected_predictions: Vec<EditPredictionRejection>,
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    // Bounded history of predictions shown to the user (capped at 50).
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
203
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}

/// Tunable parameters for prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

/// How context is gathered: model-driven search (agentic) or syntax index.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

/// Options for agentic context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
232
233impl ContextMode {
234 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
235 match self {
236 ContextMode::Agentic(options) => &options.excerpt,
237 ContextMode::Syntax(options) => &options.excerpt,
238 }
239 }
240}
241
/// Events emitted on the debug channel (see `Zeta::debug_info`) as a
/// prediction request progresses through its stages.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Emitted when context retrieval begins, carrying the prompt sent to the
/// search model.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Timestamped progress marker for a context-retrieval stage.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Details of an issued prediction request, including a receiver for the
/// model's response and its latency.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// The search queries generated by the context-retrieval model.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
282
/// Per-project prediction state tracked by [`Zeta`].
struct ZetaProject {
    // Only populated in `ContextMode::Syntax`.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized change events, oldest first; bounded by `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // The most recent (still-coalescing) change, not yet finalized into `events`.
    last_event: Option<LastEvent>,
    // Most-recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most two in-flight requests at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Used to throttle refreshes per entity (see `THROTTLE_TIMEOUT`).
    last_prediction_refresh: Option<(EntityId, Instant)>,
    cancelled_predictions: HashSet<usize>,
    // Context excerpts retrieved for the next prediction, per buffer.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
301
impl ZetaProject {
    /// All change events, including the still-coalescing `last_event`
    /// finalized on the fly (without being consumed).
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks a pending prediction as cancelled and, once its task resolves to
    /// a prediction id, reports that prediction as rejected (Canceled).
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // The task may resolve to None (e.g. no prediction was produced).
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.reject_prediction(
                    prediction_id,
                    EditPredictionRejectReason::Canceled,
                    false,
                    cx,
                );
            })
            .ok();
        })
        .detach()
    }
}
340
/// The prediction currently offered to the user for a project, plus what
/// triggered it and whether it has been displayed yet.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    pub was_shown: bool,
}
347
impl CurrentEditPrediction {
    /// Heuristic for whether `self` (a fresh prediction) should replace
    /// `old_prediction`: never if the new one no longer interpolates against
    /// its buffer; always if the buffer changed or the old one is stale;
    /// otherwise, for single-edit predictions in the requesting buffer, only
    /// when the new edit extends the old one at the same range.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
386
/// What triggered a prediction request: a diagnostics update for the project,
/// or activity in a specific buffer.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
392
393impl PredictionRequestedBy {
394 pub fn buffer_id(&self) -> Option<EntityId> {
395 match self {
396 PredictionRequestedBy::DiagnosticsUpdate => None,
397 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
398 }
399 }
400}
401
/// An in-flight prediction request; the task resolves to the resulting
/// prediction's id, if one was produced.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}

/// A buffer Zeta is observing for edits; `snapshot` is the last state we
/// reported changes against.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent (possibly still-coalescing) buffer change, kept as a
/// before/after snapshot pair until finalized into a diff event.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // Anchor at the end of the most recent edit, used for change grouping.
    end_edit_anchor: Option<Anchor>,
}
437
438impl LastEvent {
439 pub fn finalize(
440 &self,
441 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
442 cx: &App,
443 ) -> Option<Arc<predict_edits_v3::Event>> {
444 let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
445 let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
446
447 let file = self.new_snapshot.file();
448 let old_file = self.old_snapshot.file();
449
450 let in_open_source_repo = [file, old_file].iter().all(|file| {
451 file.is_some_and(|file| {
452 license_detection_watchers
453 .get(&file.worktree_id(cx))
454 .is_some_and(|watcher| watcher.is_project_open_source())
455 })
456 });
457
458 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
459
460 if path == old_path && diff.is_empty() {
461 None
462 } else {
463 Some(Arc::new(predict_edits_v3::Event::BufferChange {
464 old_path,
465 path,
466 diff,
467 in_open_source_repo,
468 // TODO: Actually detect if this edit was predicted or not
469 predicted: false,
470 }))
471 }
472 }
473}
474
475fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
476 if let Some(file) = snapshot.file() {
477 file.full_path(cx).into()
478 } else {
479 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
480 }
481}
482
483impl Zeta {
484 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
485 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
486 }
487
488 pub fn global(
489 client: &Arc<Client>,
490 user_store: &Entity<UserStore>,
491 cx: &mut App,
492 ) -> Entity<Self> {
493 cx.try_global::<ZetaGlobal>()
494 .map(|global| global.0.clone())
495 .unwrap_or_else(|| {
496 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
497 cx.set_global(ZetaGlobal(zeta.clone()));
498 zeta
499 })
500 }
501
    /// Creates a new [`Zeta`], wiring up LLM-token refresh and a background
    /// loop that flushes batched prediction rejections to the server.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Each message on this channel triggers one flush of accumulated
        // rejections (sends are debounced in `reject_prediction`).
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals expiry.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
549
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }

    /// Installs a cache used by evals to avoid repeated model calls.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
562
563 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
564 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
565 self.debug_tx = Some(debug_watch_tx);
566 debug_watch_rx
567 }
568
    /// The currently active prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the active prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
576
577 pub fn clear_history(&mut self) {
578 for zeta_project in self.projects.values_mut() {
579 zeta_project.events.clear();
580 }
581 }
582
583 pub fn context_for_project(
584 &self,
585 project: &Entity<Project>,
586 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
587 self.projects
588 .get(&project.entity_id())
589 .and_then(|project| {
590 Some(
591 project
592 .context
593 .as_ref()?
594 .iter()
595 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
596 )
597 })
598 .into_iter()
599 .flatten()
600 }
601
602 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
603 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
604 self.user_store.read(cx).edit_prediction_usage()
605 } else {
606 None
607 }
608 }
609
    /// Starts tracking `project`, initializing its per-project state.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts observing `buffer` (within `project`) for edits.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
623
    /// Returns the state for `project`, creating it on first use. A syntax
    /// index is only built when the current context mode is `Syntax`.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
656
    /// Reacts to project events: maintains the most-recently-active path list
    /// and, behind the Zeta2 flag, refreshes predictions on diagnostics.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, deduplicating.
                    // NOTE(review): no visible cap on `recent_paths` here —
                    // presumably bounded elsewhere or acceptably small; verify.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
689
    /// Registers `buffer` for change tracking (idempotent) and ensures a
    /// license-detection watcher exists for the buffer's worktree. Both the
    /// watcher and the buffer registration are cleaned up automatically when
    /// the corresponding entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        // Drop the watcher when its worktree goes away.
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit against the last reported snapshot.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Unregister when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
749
    /// Records the buffer's changes since its last reported snapshot as a
    /// change event. Consecutive edits to the same buffer within
    /// `CHANGE_GROUPING_LINE_SPAN` rows are coalesced into one event; the
    /// event queue is capped at `EVENT_COUNT_MAX`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the final edit, used to decide coalescing.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change picks up exactly where the last
            // event's snapshot left off (same buffer, same version)...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the new edit lands near the previous one.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Make room for the event we may finalize below.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
812
    /// Presents the project's current prediction from `buffer`'s perspective:
    /// `Local` when the prediction edits this buffer, `Jump` when it lives in
    /// another buffer but should be surfaced here (either this buffer made
    /// the request, or the request came from a diagnostics update).
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
844
    /// Accepts the project's current prediction: cancels any still-pending
    /// requests and reports the acceptance to the server. No-op for the Sweep
    /// backend, which has no accept endpoint here.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the endpoint (for testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
897
    /// Sends the batched prediction rejections to the server, then drops the
    /// entries that were included in the request (new rejections may have
    /// arrived while the request was in flight, so only drain up to the last
    /// one we captured). No-op for the Sweep backend or when there is nothing
    /// to send.
    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): `.ok()` silently drops a serialization failure,
        // sending an empty body instead — presumably serialization can't
        // fail for this type; verify.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
945
946 fn reject_current_prediction(
947 &mut self,
948 reason: EditPredictionRejectReason,
949 project: &Entity<Project>,
950 cx: &mut Context<Self>,
951 ) {
952 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
953 project_state.pending_predictions.clear();
954 if let Some(prediction) = project_state.current_prediction.take() {
955 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
956 }
957 };
958 }
959
960 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
961 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
962 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
963 if !current_prediction.was_shown {
964 current_prediction.was_shown = true;
965 self.shown_predictions
966 .push_front(current_prediction.prediction.clone());
967 if self.shown_predictions.len() > 50 {
968 let completion = self.shown_predictions.pop_back().unwrap();
969 self.rated_predictions.remove(&completion.id);
970 }
971 }
972 }
973 }
974 }
975
    /// Queues a rejection for `prediction_id`. Rejections are flushed to the
    /// server after a debounce window, or immediately once the per-request
    /// batch limit is reached.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            reason,
            was_shown,
        });

        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
        let reject_tx = self.reject_predictions_tx.clone();
        // Replacing the task resets the debounce timer on every new rejection.
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(REJECT_REQUEST_DEBOUNCE)
                    .await;
            }
            reject_tx.unbounded_send(()).log_err();
        }));
    }
1002
1003 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1004 self.projects
1005 .get(&project.entity_id())
1006 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1007 }
1008
    /// Requests a fresh prediction for `position` in `buffer`, throttled per
    /// buffer via `queue_prediction_refresh`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(&project, &buffer, position, cx)
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                // Tag the result with the requesting buffer so the UI can
                // decide where to surface it.
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1038
    /// Requests a prediction at the project's next diagnostic location,
    /// throttled per project. Skipped entirely when a buffer-triggered
    /// prediction is already current, and the result is discarded if one
    /// arrives while this request is in flight.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active entry, if it has one.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find the next diagnostic to jump to, if any.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // If a buffer-triggered prediction showed up meanwhile,
                        // prefer it and report ours as rejected.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1113
    /// Minimum interval between consecutive prediction refreshes keyed to the
    /// same throttle entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero in tests so refreshes run without artificial delays.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1118
1119 fn queue_prediction_refresh(
1120 &mut self,
1121 project: Entity<Project>,
1122 throttle_entity: EntityId,
1123 cx: &mut Context<Self>,
1124 do_refresh: impl FnOnce(
1125 WeakEntity<Self>,
1126 &mut AsyncApp,
1127 )
1128 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1129 + 'static,
1130 ) {
1131 let zeta_project = self.get_or_init_zeta_project(&project, cx);
1132 let pending_prediction_id = zeta_project.next_pending_prediction_id;
1133 zeta_project.next_pending_prediction_id += 1;
1134 let last_request = zeta_project.last_prediction_refresh;
1135
1136 let task = cx.spawn(async move |this, cx| {
1137 if let Some((last_entity, last_timestamp)) = last_request
1138 && throttle_entity == last_entity
1139 && let Some(timeout) =
1140 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1141 {
1142 cx.background_executor().timer(timeout).await;
1143 }
1144
1145 // If this task was cancelled before the throttle timeout expired,
1146 // do not perform a request.
1147 let mut is_cancelled = true;
1148 this.update(cx, |this, cx| {
1149 let project_state = this.get_or_init_zeta_project(&project, cx);
1150 if !project_state
1151 .cancelled_predictions
1152 .remove(&pending_prediction_id)
1153 {
1154 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1155 is_cancelled = false;
1156 }
1157 })
1158 .ok();
1159 if is_cancelled {
1160 return None;
1161 }
1162
1163 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1164 let new_prediction_id = new_prediction_result
1165 .as_ref()
1166 .map(|(prediction, _)| prediction.id.clone());
1167
1168 // When a prediction completes, remove it from the pending list, and cancel
1169 // any pending predictions that were enqueued before it.
1170 this.update(cx, |this, cx| {
1171 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1172
1173 let is_cancelled = zeta_project
1174 .cancelled_predictions
1175 .remove(&pending_prediction_id);
1176
1177 let new_current_prediction = if !is_cancelled
1178 && let Some((prediction_result, requested_by)) = new_prediction_result
1179 {
1180 match prediction_result.prediction {
1181 Ok(prediction) => {
1182 let new_prediction = CurrentEditPrediction {
1183 requested_by,
1184 prediction,
1185 was_shown: false,
1186 };
1187
1188 if let Some(current_prediction) =
1189 zeta_project.current_prediction.as_ref()
1190 {
1191 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1192 {
1193 this.reject_current_prediction(
1194 EditPredictionRejectReason::Replaced,
1195 &project,
1196 cx,
1197 );
1198
1199 Some(new_prediction)
1200 } else {
1201 this.reject_prediction(
1202 new_prediction.prediction.id,
1203 EditPredictionRejectReason::CurrentPreferred,
1204 false,
1205 cx,
1206 );
1207 None
1208 }
1209 } else {
1210 Some(new_prediction)
1211 }
1212 }
1213 Err(reject_reason) => {
1214 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1215 None
1216 }
1217 }
1218 } else {
1219 None
1220 };
1221
1222 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1223
1224 if let Some(new_prediction) = new_current_prediction {
1225 zeta_project.current_prediction = Some(new_prediction);
1226 }
1227
1228 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
1229 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1230 if pending_prediction.id == pending_prediction_id {
1231 pending_predictions.remove(ix);
1232 for pending_prediction in pending_predictions.drain(0..ix) {
1233 zeta_project.cancel_pending_prediction(pending_prediction, cx)
1234 }
1235 break;
1236 }
1237 }
1238 this.get_or_init_zeta_project(&project, cx)
1239 .pending_predictions = pending_predictions;
1240 cx.notify();
1241 })
1242 .ok();
1243
1244 new_prediction_id
1245 });
1246
1247 if zeta_project.pending_predictions.len() <= 1 {
1248 zeta_project.pending_predictions.push(PendingPrediction {
1249 id: pending_prediction_id,
1250 task,
1251 });
1252 } else if zeta_project.pending_predictions.len() == 2 {
1253 let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
1254 zeta_project.pending_predictions.push(PendingPrediction {
1255 id: pending_prediction_id,
1256 task,
1257 });
1258 zeta_project.cancel_pending_prediction(pending_prediction, cx);
1259 }
1260 }
1261
1262 pub fn request_prediction(
1263 &mut self,
1264 project: &Entity<Project>,
1265 active_buffer: &Entity<Buffer>,
1266 position: language::Anchor,
1267 cx: &mut Context<Self>,
1268 ) -> Task<Result<Option<EditPredictionResult>>> {
1269 self.request_prediction_internal(
1270 project.clone(),
1271 active_buffer.clone(),
1272 position,
1273 cx.has_flag::<Zeta2FeatureFlag>(),
1274 cx,
1275 )
1276 }
1277
    /// Dispatches a prediction request to the configured model backend
    /// (zeta1, zeta2, or sweep). When the model returns no prediction and
    /// `allow_jump` is set, retries once at the nearest diagnostic outside the
    /// already-searched window (with `allow_jump = false`, so at most one jump).
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Half-height (in rows) of the diagnostic window around the cursor.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there were recent edit events; the jump target
                // must lie outside the diagnostic window already searched above.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Recurse with allow_jump = false so we jump at most once.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1365
    /// Picks a diagnostic location to "jump" a prediction request to.
    ///
    /// First preference: in the active buffer, the primary diagnostic closest
    /// (by row distance) to the cursor that falls outside
    /// `active_buffer_diagnostic_search_range` (i.e. was not already covered by
    /// the previous request). Fallback: open the other project file with
    /// diagnostics whose path shares the most leading components with the
    /// active file, and use its most severe diagnostic. Returns `None` when
    /// neither yields a location.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    // Already covered by the previous request's window.
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Lower severity values are more severe, so `min_by_key` picks
                // the most severe diagnostic in the chosen buffer.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1444
    /// Builds and sends a zeta2 edit-prediction request, then parses the
    /// model's response into an `EditPredictionResult`.
    ///
    /// On the foreground: snapshots everything needed (context files, worktree
    /// snapshots, diagnostics, tokens). On a background task: assembles the
    /// cloud request (agentic or syntax context mode), builds the prompt, sends
    /// it, and parses the returned diff/XML edits against the included files.
    /// A final foreground task reports usage and wraps the parsed edits.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let file = active_buffer.read(cx).file();
        // Parent directory of the active file on disk, if it has one.
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        // Under eval, wait for context buffers to finish parsing so results
        // are deterministic.
        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // (entity, snapshot, full path, anchor ranges) per context file.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering: by path, then by number of ranges.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            // No usable excerpt around the cursor: no request.
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // entry (moving it last), or append a new entry.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line-based excerpts for the
                        // wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // When debugging, open a one-shot channel so the response (or
                // error) can be forwarded to the debug view later.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug escape hatch: skip the network request entirely.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                // A response with no text is a valid "no edits" answer.
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((Some((request_id, None)), usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path from the model's output back to one of the
                // buffers we sent as context.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output according to the prompt format used.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        Some((
                            inputs,
                            edited_buffer,
                            edited_buffer_snapshot.clone(),
                            edits,
                            received_response_at,
                        )),
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, prediction)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // No parsed edits => the model explicitly predicted nothing.
                let Some((
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = prediction
                else {
                    return Ok(Some(EditPredictionResult {
                        id,
                        prediction: Err(EditPredictionRejectReason::Empty),
                    }));
                };

                // TODO telemetry: duration, etc
                Ok(Some(
                    EditPredictionResult::new(
                        id,
                        &edited_buffer,
                        &edited_buffer_snapshot,
                        edits.into(),
                        buffer_snapshotted_at,
                        received_response_at,
                        inputs,
                        cx,
                    )
                    .await,
                ))
            }
        })
    }
1860
    /// Sends `request` to the edit-prediction endpoint — either the
    /// `PREDICT_EDITS_URL` override or the client's `/predict_edits/raw` LLM
    /// route — returning the parsed response plus usage info (from headers).
    ///
    /// With `eval-support` enabled, responses are cached keyed by a hash of the
    /// URL and serialized request, so repeated eval runs skip the network.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            // Cache hit: return the stored response without sending anything.
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Cache miss above: record the fresh response for future eval runs.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1918
1919 fn handle_api_response<T>(
1920 this: &WeakEntity<Self>,
1921 response: Result<(T, Option<EditPredictionUsage>)>,
1922 cx: &mut gpui::AsyncApp,
1923 ) -> Result<T> {
1924 match response {
1925 Ok((data, usage)) => {
1926 if let Some(usage) = usage {
1927 this.update(cx, |this, cx| {
1928 this.user_store.update(cx, |user_store, cx| {
1929 user_store.update_edit_prediction_usage(usage, cx);
1930 });
1931 })
1932 .ok();
1933 }
1934 Ok(data)
1935 }
1936 Err(err) => {
1937 if err.is::<ZedUpdateRequiredError>() {
1938 cx.update(|cx| {
1939 this.update(cx, |this, _cx| {
1940 this.update_required = true;
1941 })
1942 .ok();
1943
1944 let error_message: SharedString = err.to_string().into();
1945 show_app_notification(
1946 NotificationId::unique::<ZedUpdateRequiredError>(),
1947 cx,
1948 move |cx| {
1949 cx.new(|cx| {
1950 ErrorMessagePrompt::new(error_message.clone(), cx)
1951 .with_link_button("Update Zed", "https://zed.dev/releases")
1952 })
1953 },
1954 );
1955 })
1956 .ok();
1957 }
1958 Err(err)
1959 }
1960 }
1961 }
1962
    /// Sends an authenticated POST to the LLM backend and deserializes the
    /// JSON response.
    ///
    /// Retries exactly once when the server signals an expired LLM token (via
    /// `EXPIRED_LLM_TOKEN_HEADER_NAME`), refreshing the token first. Fails with
    /// `ZedUpdateRequiredError` when the server's minimum-version header
    /// exceeds `app_version`.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: initial attempt plus one token-refresh retry.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server's minimum supported client version, if sent.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh and retry once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2026
    /// Gap since the last edit after which the next edit triggers an immediate
    /// context refresh (no debounce).
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refreshes while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2029
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context refresh only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" = no refresh attempt within CONTEXT_RETRIEVAL_IDLE_DURATION
        // (or never). After idle we refresh immediately; otherwise we debounce.
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the previous debounce task drops it, restarting the wait.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold the task on the project so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2077
2078 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2079 // and avoid spawning more than one concurrent task.
2080 pub fn refresh_context(
2081 &mut self,
2082 project: Entity<Project>,
2083 buffer: Entity<language::Buffer>,
2084 cursor_position: language::Anchor,
2085 cx: &mut Context<Self>,
2086 ) -> Task<Result<()>> {
2087 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2088 return Task::ready(anyhow::Ok(()));
2089 };
2090
2091 let ContextMode::Agentic(options) = &self.options().context else {
2092 return Task::ready(anyhow::Ok(()));
2093 };
2094
2095 let snapshot = buffer.read(cx).snapshot();
2096 let cursor_point = cursor_position.to_point(&snapshot);
2097 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
2098 cursor_point,
2099 &snapshot,
2100 &options.excerpt,
2101 None,
2102 ) else {
2103 return Task::ready(Ok(()));
2104 };
2105
2106 let app_version = AppVersion::global(cx);
2107 let client = self.client.clone();
2108 let llm_token = self.llm_token.clone();
2109 let debug_tx = self.debug_tx.clone();
2110 let current_file_path: Arc<Path> = snapshot
2111 .file()
2112 .map(|f| f.full_path(cx).into())
2113 .unwrap_or_else(|| Path::new("untitled").into());
2114
2115 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
2116 predict_edits_v3::PlanContextRetrievalRequest {
2117 excerpt: cursor_excerpt.text(&snapshot).body,
2118 excerpt_path: current_file_path,
2119 excerpt_line_range: cursor_excerpt.line_range,
2120 cursor_file_max_row: Line(snapshot.max_point().row),
2121 events: zeta_project.events(cx),
2122 },
2123 ) {
2124 Ok(prompt) => prompt,
2125 Err(err) => {
2126 return Task::ready(Err(err));
2127 }
2128 };
2129
2130 if let Some(debug_tx) = &debug_tx {
2131 debug_tx
2132 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
2133 ZetaContextRetrievalStartedDebugInfo {
2134 project: project.clone(),
2135 timestamp: Instant::now(),
2136 search_prompt: prompt.clone(),
2137 },
2138 ))
2139 .ok();
2140 }
2141
2142 pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
2143 let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
2144 language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
2145 );
2146
2147 let description = schema
2148 .get("description")
2149 .and_then(|description| description.as_str())
2150 .unwrap()
2151 .to_string();
2152
2153 (schema.into(), description)
2154 });
2155
2156 let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
2157
2158 let request = open_ai::Request {
2159 model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
2160 messages: vec![open_ai::RequestMessage::User {
2161 content: open_ai::MessageContent::Plain(prompt),
2162 }],
2163 stream: false,
2164 max_completion_tokens: None,
2165 stop: Default::default(),
2166 temperature: 0.7,
2167 tool_choice: None,
2168 parallel_tool_calls: None,
2169 tools: vec![open_ai::ToolDefinition::Function {
2170 function: FunctionDefinition {
2171 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
2172 description: Some(tool_description),
2173 parameters: Some(tool_schema),
2174 },
2175 }],
2176 prompt_cache_key: None,
2177 reasoning_effort: None,
2178 };
2179
2180 #[cfg(feature = "eval-support")]
2181 let eval_cache = self.eval_cache.clone();
2182
2183 cx.spawn(async move |this, cx| {
2184 log::trace!("Sending search planning request");
2185 let response = Self::send_raw_llm_request(
2186 request,
2187 client,
2188 llm_token,
2189 app_version,
2190 #[cfg(feature = "eval-support")]
2191 eval_cache.clone(),
2192 #[cfg(feature = "eval-support")]
2193 EvalCacheEntryKind::Context,
2194 )
2195 .await;
2196 let mut response = Self::handle_api_response(&this, response, cx)?;
2197 log::trace!("Got search planning response");
2198
2199 let choice = response
2200 .choices
2201 .pop()
2202 .context("No choices in retrieval response")?;
2203 let open_ai::RequestMessage::Assistant {
2204 content: _,
2205 tool_calls,
2206 } = choice.message
2207 else {
2208 anyhow::bail!("Retrieval response didn't include an assistant message");
2209 };
2210
2211 let mut queries: Vec<SearchToolQuery> = Vec::new();
2212 for tool_call in tool_calls {
2213 let open_ai::ToolCallContent::Function { function } = tool_call.content;
2214 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
2215 log::warn!(
2216 "Context retrieval response tried to call an unknown tool: {}",
2217 function.name
2218 );
2219
2220 continue;
2221 }
2222
2223 let input: SearchToolInput = serde_json::from_str(&function.arguments)
2224 .with_context(|| format!("invalid search json {}", &function.arguments))?;
2225 queries.extend(input.queries);
2226 }
2227
2228 if let Some(debug_tx) = &debug_tx {
2229 debug_tx
2230 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
2231 ZetaSearchQueryDebugInfo {
2232 project: project.clone(),
2233 timestamp: Instant::now(),
2234 search_queries: queries.clone(),
2235 },
2236 ))
2237 .ok();
2238 }
2239
2240 log::trace!("Running retrieval search: {queries:#?}");
2241
2242 let related_excerpts_result = retrieval_search::run_retrieval_searches(
2243 queries,
2244 project.clone(),
2245 #[cfg(feature = "eval-support")]
2246 eval_cache,
2247 cx,
2248 )
2249 .await;
2250
2251 log::trace!("Search queries executed");
2252
2253 if let Some(debug_tx) = &debug_tx {
2254 debug_tx
2255 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
2256 ZetaContextRetrievalDebugInfo {
2257 project: project.clone(),
2258 timestamp: Instant::now(),
2259 },
2260 ))
2261 .ok();
2262 }
2263
2264 this.update(cx, |this, _cx| {
2265 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
2266 return Ok(());
2267 };
2268 zeta_project.refresh_context_task.take();
2269 if let Some(debug_tx) = &this.debug_tx {
2270 debug_tx
2271 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
2272 ZetaContextRetrievalDebugInfo {
2273 project,
2274 timestamp: Instant::now(),
2275 },
2276 ))
2277 .ok();
2278 }
2279 match related_excerpts_result {
2280 Ok(excerpts) => {
2281 zeta_project.context = Some(excerpts);
2282 Ok(())
2283 }
2284 Err(error) => Err(error),
2285 }
2286 })?
2287 })
2288 }
2289
2290 pub fn set_context(
2291 &mut self,
2292 project: Entity<Project>,
2293 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2294 ) {
2295 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2296 zeta_project.context = Some(context);
2297 }
2298 }
2299
2300 fn gather_nearby_diagnostics(
2301 cursor_offset: usize,
2302 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2303 snapshot: &BufferSnapshot,
2304 max_diagnostics_bytes: usize,
2305 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2306 // TODO: Could make this more efficient
2307 let mut diagnostic_groups = Vec::new();
2308 for (language_server_id, diagnostics) in diagnostic_sets {
2309 let mut groups = Vec::new();
2310 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2311 diagnostic_groups.extend(
2312 groups
2313 .into_iter()
2314 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2315 );
2316 }
2317
2318 // sort by proximity to cursor
2319 diagnostic_groups.sort_by_key(|group| {
2320 let range = &group.entries[group.primary_ix].range;
2321 if range.start >= cursor_offset {
2322 range.start - cursor_offset
2323 } else if cursor_offset >= range.end {
2324 cursor_offset - range.end
2325 } else {
2326 (cursor_offset - range.start).min(range.end - cursor_offset)
2327 }
2328 });
2329
2330 let mut results = Vec::new();
2331 let mut diagnostic_groups_truncated = false;
2332 let mut diagnostics_byte_count = 0;
2333 for group in diagnostic_groups {
2334 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2335 diagnostics_byte_count += raw_value.get().len();
2336 if diagnostics_byte_count > max_diagnostics_bytes {
2337 diagnostic_groups_truncated = true;
2338 break;
2339 }
2340 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2341 }
2342
2343 (results, diagnostic_groups_truncated)
2344 }
2345
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI without sending
    /// it, using syntax-based context gathering on a background task.
    ///
    /// Returns an error when the buffer has no file path; panics when agentic
    /// ("LLM") context mode is configured, which the CLI does not support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone of the project's syntax index state, if it has one.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory (None when the file
        // sits at a filesystem root and `pop` fails).
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2423
2424 pub fn wait_for_initial_indexing(
2425 &mut self,
2426 project: &Entity<Project>,
2427 cx: &mut Context<Self>,
2428 ) -> Task<Result<()>> {
2429 let zeta_project = self.get_or_init_zeta_project(project, cx);
2430 if let Some(syntax_index) = &zeta_project.syntax_index {
2431 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2432 } else {
2433 Task::ready(Ok(()))
2434 }
2435 }
2436
2437 fn is_file_open_source(
2438 &self,
2439 project: &Entity<Project>,
2440 file: &Arc<dyn File>,
2441 cx: &App,
2442 ) -> bool {
2443 if !file.is_local() || file.is_private() {
2444 return false;
2445 }
2446 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2447 return false;
2448 };
2449 zeta_project
2450 .license_detection_watchers
2451 .get(&file.worktree_id(cx))
2452 .as_ref()
2453 .is_some_and(|watcher| watcher.is_project_open_source())
2454 }
2455
2456 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2457 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2458 }
2459
2460 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2461 if !self.data_collection_choice.is_enabled() {
2462 return false;
2463 }
2464 events.iter().all(|event| {
2465 matches!(
2466 event.as_ref(),
2467 Event::BufferChange {
2468 in_open_source_repo: true,
2469 ..
2470 }
2471 )
2472 })
2473 }
2474
2475 fn load_data_collection_choice() -> DataCollectionChoice {
2476 let choice = KEY_VALUE_STORE
2477 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2478 .log_err()
2479 .flatten();
2480
2481 match choice.as_deref() {
2482 Some("true") => DataCollectionChoice::Enabled,
2483 Some("false") => DataCollectionChoice::Disabled,
2484 Some(_) => {
2485 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2486 DataCollectionChoice::NotAnswered
2487 }
2488 None => DataCollectionChoice::NotAnswered,
2489 }
2490 }
2491
    /// Iterator over the predictions that have been shown to the user, in
    /// storage order (reversible).
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2495
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2499
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2503
    /// Records a user rating for `prediction`: marks it as rated, reports a
    /// telemetry event containing the prediction's inputs and its output as a
    /// unified diff, and flushes telemetry immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Send buffered telemetry right away; detach the resulting task.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2522}
2523
2524pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2525 let choice = res.choices.pop()?;
2526 let output_text = match choice.message {
2527 open_ai::RequestMessage::Assistant {
2528 content: Some(open_ai::MessageContent::Plain(content)),
2529 ..
2530 } => content,
2531 open_ai::RequestMessage::Assistant {
2532 content: Some(open_ai::MessageContent::Multipart(mut content)),
2533 ..
2534 } => {
2535 if content.is_empty() {
2536 log::error!("No output from Baseten completion response");
2537 return None;
2538 }
2539
2540 match content.remove(0) {
2541 open_ai::MessagePart::Text { text } => text,
2542 open_ai::MessagePart::Image { .. } => {
2543 log::error!("Expected text, got an image");
2544 return None;
2545 }
2546 }
2547 }
2548 _ => {
2549 log::error!("Invalid response message: {:?}", choice.message);
2550 return None;
2551 }
2552 };
2553 Some(output_text)
2554}
2555
/// Error surfaced when the server requires a newer Zed build for edit
/// predictions; `minimum_version` is presumably taken from the
/// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header — confirm at the
/// construction site.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2563
2564fn make_syntax_context_cloud_request(
2565 excerpt_path: Arc<Path>,
2566 context: EditPredictionContext,
2567 events: Vec<Arc<predict_edits_v3::Event>>,
2568 can_collect_data: bool,
2569 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2570 diagnostic_groups_truncated: bool,
2571 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2572 debug_info: bool,
2573 worktrees: &Vec<worktree::Snapshot>,
2574 index_state: Option<&SyntaxIndexState>,
2575 prompt_max_bytes: Option<usize>,
2576 prompt_format: PromptFormat,
2577) -> predict_edits_v3::PredictEditsRequest {
2578 let mut signatures = Vec::new();
2579 let mut declaration_to_signature_index = HashMap::default();
2580 let mut referenced_declarations = Vec::new();
2581
2582 for snippet in context.declarations {
2583 let project_entry_id = snippet.declaration.project_entry_id();
2584 let Some(path) = worktrees.iter().find_map(|worktree| {
2585 worktree.entry_for_id(project_entry_id).map(|entry| {
2586 let mut full_path = RelPathBuf::new();
2587 full_path.push(worktree.root_name());
2588 full_path.push(&entry.path);
2589 full_path
2590 })
2591 }) else {
2592 continue;
2593 };
2594
2595 let parent_index = index_state.and_then(|index_state| {
2596 snippet.declaration.parent().and_then(|parent| {
2597 add_signature(
2598 parent,
2599 &mut declaration_to_signature_index,
2600 &mut signatures,
2601 index_state,
2602 )
2603 })
2604 });
2605
2606 let (text, text_is_truncated) = snippet.declaration.item_text();
2607 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2608 path: path.as_std_path().into(),
2609 text: text.into(),
2610 range: snippet.declaration.item_line_range(),
2611 text_is_truncated,
2612 signature_range: snippet.declaration.signature_range_in_item_text(),
2613 parent_index,
2614 signature_score: snippet.score(DeclarationStyle::Signature),
2615 declaration_score: snippet.score(DeclarationStyle::Declaration),
2616 score_components: snippet.components,
2617 });
2618 }
2619
2620 let excerpt_parent = index_state.and_then(|index_state| {
2621 context
2622 .excerpt
2623 .parent_declarations
2624 .last()
2625 .and_then(|(parent, _)| {
2626 add_signature(
2627 *parent,
2628 &mut declaration_to_signature_index,
2629 &mut signatures,
2630 index_state,
2631 )
2632 })
2633 });
2634
2635 predict_edits_v3::PredictEditsRequest {
2636 excerpt_path,
2637 excerpt: context.excerpt_text.body,
2638 excerpt_line_range: context.excerpt.line_range,
2639 excerpt_range: context.excerpt.range,
2640 cursor_point: predict_edits_v3::Point {
2641 line: predict_edits_v3::Line(context.cursor_point.row),
2642 column: context.cursor_point.column,
2643 },
2644 referenced_declarations,
2645 included_files: vec![],
2646 signatures,
2647 excerpt_parent,
2648 events,
2649 can_collect_data,
2650 diagnostic_groups,
2651 diagnostic_groups_truncated,
2652 git_info,
2653 debug_info,
2654 prompt_max_bytes,
2655 prompt_format,
2656 }
2657}
2658
2659fn add_signature(
2660 declaration_id: DeclarationId,
2661 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2662 signatures: &mut Vec<Signature>,
2663 index: &SyntaxIndexState,
2664) -> Option<usize> {
2665 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2666 return Some(*signature_index);
2667 }
2668 let Some(parent_declaration) = index.declaration(declaration_id) else {
2669 log::error!("bug: missing parent declaration");
2670 return None;
2671 };
2672 let parent_index = parent_declaration.parent().and_then(|parent| {
2673 add_signature(parent, declaration_to_signature_index, signatures, index)
2674 });
2675 let (text, text_is_truncated) = parent_declaration.signature_text();
2676 let signature_index = signatures.len();
2677 signatures.push(Signature {
2678 text: text.into(),
2679 text_is_truncated,
2680 parent_index,
2681 range: parent_declaration.signature_line_range(),
2682 });
2683 declaration_to_signature_index.insert(declaration_id, signature_index);
2684 Some(signature_index)
2685}
2686
/// Key for eval cache entries: the request kind paired with a `u64`
/// (presumably a hash of the request — confirm at implementors/call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2689
/// The kind of LLM request an eval cache entry belongs to; its `Display`
/// impl provides the lowercase label used for cache entries.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2697
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase label used for eval cache entries.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{name}")
    }
}
2708
/// Cache of LLM responses used by evals, so repeated runs can replay
/// previously-recorded interactions instead of issuing live requests.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if any.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Records `value` as the response for `key`; `input` is presumably the
    /// request payload kept alongside for inspection — confirm implementors.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2714
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not responded to the prompt yet.
    NotAnswered,
    /// The user explicitly opted in.
    Enabled,
    /// The user explicitly opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True when the user has made any explicit choice.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Flips the choice; an unanswered prompt toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        if self.is_enabled() {
            Self::Disabled
        } else {
            Self::Enabled
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2755
2756struct ZedPredictUpsell;
2757
2758impl Dismissable for ZedPredictUpsell {
2759 const KEY: &'static str = "dismissed-edit-predict-upsell";
2760
2761 fn dismissed() -> bool {
2762 // To make this backwards compatible with older versions of Zed, we
2763 // check if the user has seen the previous Edit Prediction Onboarding
2764 // before, by checking the data collection choice which was written to
2765 // the database once the user clicked on "Accept and Enable"
2766 if KEY_VALUE_STORE
2767 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2768 .log_err()
2769 .is_some_and(|s| s.is_some())
2770 {
2771 return true;
2772 }
2773
2774 KEY_VALUE_STORE
2775 .read_kvp(Self::KEY)
2776 .log_err()
2777 .is_some_and(|s| s.is_some())
2778 }
2779}
2780
/// Whether the edit-predictions upsell modal should be shown, i.e. the user
/// has never dismissed it (including via the legacy onboarding flow).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2784
/// Registers the edit-prediction workspace actions and sets up their
/// command-palette feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // The rating modal is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): "reset onboarding" is implemented by clearing the
        // configured edit-prediction provider back to `None` — presumably
        // that makes onboarding eligible to show again; confirm.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2820
/// Hides or shows the edit-prediction command-palette actions based on the
/// "disable AI" setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Every zeta action; used to hide the whole feature when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Start with the gated actions hidden; the observers below re-show them
    // once the relevant conditions hold.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Also react to the flag flipping at runtime (settings observer above
    // only fires on settings changes).
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2872
2873#[cfg(test)]
2874mod tests {
2875 use std::{path::Path, sync::Arc};
2876
2877 use client::UserStore;
2878 use clock::FakeSystemClock;
2879 use cloud_llm_client::{
2880 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2881 };
2882 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2883 use futures::{
2884 AsyncReadExt, StreamExt,
2885 channel::{mpsc, oneshot},
2886 };
2887 use gpui::{
2888 Entity, TestAppContext,
2889 http_client::{FakeHttpClient, Response},
2890 prelude::*,
2891 };
2892 use indoc::indoc;
2893 use language::OffsetRangeExt as _;
2894 use open_ai::Usage;
2895 use pretty_assertions::{assert_eq, assert_matches};
2896 use project::{FakeFs, Project};
2897 use serde_json::json;
2898 use settings::SettingsStore;
2899 use util::path;
2900 use uuid::Uuid;
2901
2902 use crate::{BufferEditPrediction, Zeta};
2903
    /// Walks `current_prediction_for_buffer` through a full lifecycle: a
    /// prediction editing the current file, a context-refresh round trip, and
    /// a prediction targeting another file (reported as `Jump` from the
    /// requesting buffer, then as `Local` once the target buffer is open).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        // (The planning request arrives on the same `predict` channel;
        // answer it with a single search tool call.)
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's point of view the prediction is a jump into 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt itself is open, the same prediction is local to it.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3050
    /// Basic request/response round trip: a single-hunk diff response is
    /// turned into exactly one edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // "How" -> "How are you?" surfaces as one insertion at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3114
    /// Verifies that buffer edits made before a prediction request are
    /// included in the prompt as a unified diff of the change history.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so subsequent edits are recorded as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The recorded edit (inserting "How" on line 2) must appear in the
        // prompt as a diff hunk.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3188
    /// A response whose diff is a no-op must not surface a prediction, and
    /// must be reported to the server as rejected with reason `Empty`.
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Replaces "How" with the identical "How": zero effective edits.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How
             Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3252
    /// If the user types the predicted text while the request is in flight,
    /// the interpolated prediction becomes empty: nothing is surfaced, and
    /// the prediction is rejected with reason `InterpolatedEmpty`.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Apply the predicted edit locally before the response arrives.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3311
    // Diff fixture shared by several tests below: rewrites line 2 of
    // root/foo.md from "How" to "How are you?".
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are you?
        Bye
    "};
3321
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        // Two sequential prediction requests: the first completes and becomes
        // current, then the second completes and replaces it. The replaced
        // first prediction must be reported back with reason `Replaced`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // First prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3401
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        // When a newer prediction arrives that is judged worse than the one
        // currently shown, the current prediction is kept and the newer one is
        // rejected with reason `CurrentPreferred`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // First prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction ("How are" is a strict prefix of the
        // current prediction's "How are you?")
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3490
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        // Two prediction requests are in flight at once. When the *later*
        // request completes first, the earlier one is considered stale: its
        // response is discarded when it finally arrives, and it is reported
        // with reason `Canceled`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // Now deliver the (stale) first response.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3576
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        // At most two prediction requests may be pending at once. Starting a
        // third while two are pending cancels the second (index 1). The first
        // still completes and becomes current; the cancelled second's response
        // is ignored; the third completes and replaces the first. The final
        // reject report therefore contains the `Canceled` second and the
        // `Replaced` first.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // Deliver a response for the cancelled second request.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second (canceled) and first (replaced) are reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3704
3705 // Skipped until we start including diagnostics in prompt
3706 // #[gpui::test]
3707 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3708 // let (zeta, mut req_rx) = init_test(cx);
3709 // let fs = FakeFs::new(cx.executor());
3710 // fs.insert_tree(
3711 // "/root",
3712 // json!({
3713 // "foo.md": "Hello!\nBye"
3714 // }),
3715 // )
3716 // .await;
3717 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3718
3719 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3720 // let diagnostic = lsp::Diagnostic {
3721 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3722 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3723 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3724 // ..Default::default()
3725 // };
3726
3727 // project.update(cx, |project, cx| {
3728 // project.lsp_store().update(cx, |lsp_store, cx| {
3729 // // Create some diagnostics
3730 // lsp_store
3731 // .update_diagnostics(
3732 // LanguageServerId(0),
3733 // lsp::PublishDiagnosticsParams {
3734 // uri: path_to_buffer_uri.clone(),
3735 // diagnostics: vec![diagnostic],
3736 // version: None,
3737 // },
3738 // None,
3739 // language::DiagnosticSourceKind::Pushed,
3740 // &[],
3741 // cx,
3742 // )
3743 // .unwrap();
3744 // });
3745 // });
3746
3747 // let buffer = project
3748 // .update(cx, |project, cx| {
3749 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3750 // project.open_buffer(path, cx)
3751 // })
3752 // .await
3753 // .unwrap();
3754
3755 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3756 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3757
3758 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3759 // zeta.request_prediction(&project, &buffer, position, cx)
3760 // });
3761
3762 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3763
3764 // assert_eq!(request.diagnostic_groups.len(), 1);
3765 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3766 // .unwrap();
3767 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3768 // assert_eq!(
3769 // value,
3770 // json!({
3771 // "entries": [{
3772 // "range": {
3773 // "start": 8,
3774 // "end": 10
3775 // },
3776 // "diagnostic": {
3777 // "source": null,
3778 // "code": null,
3779 // "code_description": null,
3780 // "severity": 1,
3781 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3782 // "markdown": null,
3783 // "group_id": 0,
3784 // "is_primary": true,
3785 // "is_disk_based": false,
3786 // "is_unnecessary": false,
3787 // "source_kind": "Pushed",
3788 // "data": null,
3789 // "underline": true
3790 // }
3791 // }],
3792 // "primary_ix": 0
3793 // })
3794 // );
3795 // }
3796
3797 fn model_response(text: &str) -> open_ai::Response {
3798 open_ai::Response {
3799 id: Uuid::new_v4().to_string(),
3800 object: "response".into(),
3801 created: 0,
3802 model: "model".into(),
3803 choices: vec![open_ai::Choice {
3804 index: 0,
3805 message: open_ai::RequestMessage::Assistant {
3806 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3807 tool_calls: vec![],
3808 },
3809 finish_reason: None,
3810 }],
3811 usage: Usage {
3812 prompt_tokens: 0,
3813 completion_tokens: 0,
3814 total_tokens: 0,
3815 },
3816 }
3817 }
3818
3819 fn prompt_from_request(request: &open_ai::Request) -> &str {
3820 assert_eq!(request.messages.len(), 1);
3821 let open_ai::RequestMessage::User {
3822 content: open_ai::MessageContent::Plain(content),
3823 ..
3824 } = &request.messages[0]
3825 else {
3826 panic!(
3827 "Request does not have single user message of type Plain. {:#?}",
3828 request
3829 );
3830 };
3831 content
3832 }
3833
    /// Receiving ends of the channels through which the fake HTTP client built
    /// in `init_test` hands intercepted requests to the test. Each item pairs
    /// the deserialized request body with a oneshot sender the test uses to
    /// supply the response.
    struct RequestChannels {
        // Requests to `/predict_edits/raw`.
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        // Requests to `/predict_edits/reject`.
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3838
    /// Creates a `Zeta` entity backed by a fake HTTP client, plus channels for
    /// intercepting its requests.
    ///
    /// The fake client handles three paths:
    /// - `/client/llm_tokens` — answered inline with a static `"test"` token;
    /// - `/predict_edits/raw` — forwarded over `RequestChannels::predict`
    ///   together with a oneshot responder, so the test controls when and how
    ///   each prediction request is answered;
    /// - `/predict_edits/reject` — likewise over `RequestChannels::reject`.
    ///
    /// Any other path panics.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block this
                                // fake request until the test responds.
                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
3905}