1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
9 ZED_VERSION_HEADER_NAME,
10};
11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
13use collections::{HashMap, HashSet};
14use command_palette_hooks::CommandPaletteFilter;
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{
17 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
18 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
19 SyntaxIndex, SyntaxIndexState,
20};
21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
22use futures::channel::{mpsc, oneshot};
23use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::{
30 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
31};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, RefreshLlmTokenListener};
34use open_ai::FunctionDefinition;
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
40use std::any::{Any as _, TypeId};
41use std::collections::{VecDeque, hash_map};
42use telemetry_events::EditPredictionRating;
43use workspace::Workspace;
44
45use std::ops::Range;
46use std::path::Path;
47use std::rc::Rc;
48use std::str::FromStr as _;
49use std::sync::{Arc, LazyLock};
50use std::time::{Duration, Instant};
51use std::{env, mem};
52use thiserror::Error;
53use util::rel_path::RelPathBuf;
54use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
56
57pub mod assemble_excerpts;
58mod license_detection;
59mod onboarding_modal;
60mod prediction;
61mod provider;
62mod rate_prediction_modal;
63pub mod retrieval_search;
64pub mod sweep_ai;
65pub mod udiff;
66mod xml_edits;
67pub mod zeta1;
68
69#[cfg(test)]
70mod zeta_tests;
71
72use crate::assemble_excerpts::assemble_excerpts;
73use crate::license_detection::LicenseDetectionWatcher;
74use crate::onboarding_modal::ZedPredictModal;
75pub use crate::prediction::EditPrediction;
76pub use crate::prediction::EditPredictionId;
77pub use crate::prediction::EditPredictionInputs;
78use crate::prediction::EditPredictionResult;
79use crate::rate_prediction_modal::{
80 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
81 ThumbsUpActivePrediction,
82};
83pub use crate::sweep_ai::SweepAi;
84use crate::zeta1::request_prediction_with_zeta1;
85pub use provider::ZetaEditPredictionProvider;
86
// Workspace actions exposed under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits landing within this many rows of the previous edit are coalesced
/// into the same pending event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
// Key under which the user's data-collection choice is persisted —
// presumably in the key-value store (write site is outside this chunk).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
103
104pub struct SweepFeatureFlag;
105
106impl FeatureFlag for SweepFeatureFlag {
107 const NAME: &str = "sweep-ai";
108}
/// Default sizing for the excerpt of text captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Split the excerpt budget evenly before/after the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Agentic (on-demand) context retrieval is the default mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Baseline [`ZetaOptions`] installed when a [`Zeta`] is constructed.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
140
/// True when `ZED_ZETA2_OLLAMA` is set to a non-empty value, which routes
/// model traffic to a local Ollama server instead of the hosted backend.
static USE_OLLAMA: LazyLock<bool> = LazyLock::new(|| {
    env::var("ZED_ZETA2_OLLAMA")
        .map(|var| !var.is_empty())
        .unwrap_or(false)
});
143static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
144 env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
145 "qwen3-coder:30b".to_string()
146 } else {
147 "yqvev8r3".to_string()
148 })
149});
150static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
151 match env::var("ZED_ZETA2_MODEL").as_deref() {
152 Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
153 Ok(model) => model,
154 Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
155 Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
156 }
157 .to_string()
158});
159static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
160 env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
161 if *USE_OLLAMA {
162 Some("http://localhost:11434/v1/chat/completions".into())
163 } else {
164 None
165 }
166 })
167});
168
/// Feature flag gating the zeta2 backend; always enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
178
/// App-global handle to the single [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
183
/// Coordinates edit predictions across projects: buffer/event tracking,
/// requests to the prediction backends, and bookkeeping for shown, rated,
/// and rejected predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token used to authenticate requests against the LLM backend.
    llm_token: LlmApiToken,
    // Keeps `llm_token` fresh whenever a refresh event is broadcast.
    _llm_token_subscription: Subscription,
    // Per-project prediction state, keyed by project entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports this client is too
    // old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — the write site is
    // outside this chunk; confirm.
    update_required: bool,
    // Sink for debug events; installed by `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    // Which backend serves predictions (Zeta1/Zeta2/Sweep).
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections buffered locally until flushed to the server in one batch.
    rejected_predictions: Vec<EditPredictionRejection>,
    // Nudges the background flusher spawned in `new`.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    // Predictions rendered to the user, newest first (capped at 50).
    shown_predictions: VecDeque<EditPrediction>,
    // Ids of shown predictions the user rated; pruned alongside
    // `shown_predictions` evictions.
    rated_predictions: HashSet<EditPredictionId>,
}
204
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
212
/// Tunable options governing how prediction requests are built.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered.
    pub context: ContextMode,
    /// Budget for the assembled prompt, in bytes.
    pub max_prompt_bytes: usize,
    /// Budget for diagnostic text included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Worker parallelism for the syntax index (Syntax mode only).
    pub file_indexing_parallelism: usize,
    // NOTE(review): presumably the time window for grouping consecutive
    // buffer changes into one event — usage is outside this chunk; confirm.
    pub buffer_change_grouping_interval: Duration,
}
222
/// Strategy for gathering context to include with a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context is retrieved on demand (no background index).
    Agentic(AgenticContextOptions),
    /// Context comes from a background [`SyntaxIndex`].
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
233
234impl ContextMode {
235 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
236 match self {
237 ContextMode::Agentic(options) => &options.excerpt,
238 ContextMode::Syntax(options) => &options.excerpt,
239 }
240 }
241}
242
/// Debug events emitted to the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

/// Shared payload for retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload describing an issued edit-prediction request.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// Time spent on context retrieval before the request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
283
/// Per-project prediction state.
struct ZetaProject {
    // Only populated in `ContextMode::Syntax`.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized buffer-change events, oldest first (bounded by EVENT_COUNT_MAX).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // Pending change still open for coalescing with nearby edits.
    last_event: Option<LastEvent>,
    // Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Monotonic id source for `PendingPrediction::id`.
    next_pending_prediction_id: usize,
    // In-flight prediction requests (at most 2).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Entity and time of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled before completing.
    cancelled_predictions: HashSet<usize>,
    // Retrieved context ranges per buffer, when available.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    // One license watcher per worktree, created lazily on buffer registration.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
302
impl ZetaProject {
    /// Returns all tracked events, appending the still-pending `last_event`
    /// (finalized on the fly) when it amounts to an actual change.
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks `pending_prediction` as cancelled and, once its task resolves to
    /// a prediction id, reports that prediction as rejected
    /// (reason `Canceled`, never shown).
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // `None` means the request produced no prediction; nothing to reject.
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.reject_prediction(
                    prediction_id,
                    EditPredictionRejectReason::Canceled,
                    false,
                    cx,
                );
            })
            .ok();
        })
        .detach()
    }
}
341
/// The prediction currently offered to the user, plus bookkeeping about its
/// origin and visibility.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been rendered to the user at least once.
    pub was_shown: bool,
}
348
349impl CurrentEditPrediction {
350 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
351 let Some(new_edits) = self
352 .prediction
353 .interpolate(&self.prediction.buffer.read(cx))
354 else {
355 return false;
356 };
357
358 if self.prediction.buffer != old_prediction.prediction.buffer {
359 return true;
360 }
361
362 let Some(old_edits) = old_prediction
363 .prediction
364 .interpolate(&old_prediction.prediction.buffer.read(cx))
365 else {
366 return true;
367 };
368
369 let requested_by_buffer_id = self.requested_by.buffer_id();
370
371 // This reduces the occurrence of UI thrash from replacing edits
372 //
373 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
374 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
375 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
376 && old_edits.len() == 1
377 && new_edits.len() == 1
378 {
379 let (old_range, old_text) = &old_edits[0];
380 let (new_range, new_text) = &new_edits[0];
381 new_range == old_range && new_text.starts_with(old_text.as_ref())
382 } else {
383 true
384 }
385 }
386}
387
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
393
394impl PredictionRequestedBy {
395 pub fn buffer_id(&self) -> Option<EntityId> {
396 match self {
397 PredictionRequestedBy::DiagnosticsUpdate => None,
398 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
399 }
400 }
401}
402
/// An in-flight prediction request; the task resolves to the prediction's id,
/// or `None` when no prediction was produced.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
415
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        // Both variants carry the same payload, so one or-pattern arm suffices.
        match self {
            BufferEditPrediction::Local { prediction }
            | BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
427
/// Per-buffer tracking state: the last-seen snapshot plus the edit/release
/// subscriptions that keep it up to date.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// A buffer change that has not yet been finalized into an event, so that
/// subsequent nearby edits can still be coalesced into it.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // End of the most recent edit; used to decide whether the next edit is
    // close enough (in rows) to coalesce.
    end_edit_anchor: Option<Anchor>,
}
438
439impl LastEvent {
440 pub fn finalize(
441 &self,
442 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
443 cx: &App,
444 ) -> Option<Arc<predict_edits_v3::Event>> {
445 let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
446 let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
447
448 let file = self.new_snapshot.file();
449 let old_file = self.old_snapshot.file();
450
451 let in_open_source_repo = [file, old_file].iter().all(|file| {
452 file.is_some_and(|file| {
453 license_detection_watchers
454 .get(&file.worktree_id(cx))
455 .is_some_and(|watcher| watcher.is_project_open_source())
456 })
457 });
458
459 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
460
461 if path == old_path && diff.is_empty() {
462 None
463 } else {
464 Some(Arc::new(predict_edits_v3::Event::BufferChange {
465 old_path,
466 path,
467 diff,
468 in_open_source_repo,
469 // TODO: Actually detect if this edit was predicted or not
470 predicted: false,
471 }))
472 }
473 }
474}
475
476fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
477 if let Some(file) = snapshot.file() {
478 file.full_path(cx).into()
479 } else {
480 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
481 }
482}
483
484impl Zeta {
485 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
486 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
487 }
488
489 pub fn global(
490 client: &Arc<Client>,
491 user_store: &Entity<UserStore>,
492 cx: &mut App,
493 ) -> Entity<Self> {
494 cx.try_global::<ZetaGlobal>()
495 .map(|global| global.0.clone())
496 .unwrap_or_else(|| {
497 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
498 cx.set_global(ZetaGlobal(zeta.clone()));
499 zeta
500 })
501 }
502
    /// Creates a new `Zeta`, wiring up the LLM-token refresh subscription and
    /// the background task that flushes batched prediction rejections.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Rejections are reported in batches: `reject_prediction` nudges this
        // channel (after a debounce) and the task below flushes the
        // accumulated rejections to the server.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener reports a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
550
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
554
555 pub fn has_sweep_api_token(&self) -> bool {
556 self.sweep_ai
557 .api_token
558 .clone()
559 .now_or_never()
560 .flatten()
561 .is_some()
562 }
563
    /// Installs an eval cache — presumably used to memoize backend requests
    /// during evals (cache reads/writes happen outside this chunk).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
568
569 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
570 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
571 self.debug_tx = Some(debug_watch_tx);
572 debug_watch_rx
573 }
574
    /// Current prediction/context options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction/context options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
582
583 pub fn clear_history(&mut self) {
584 for zeta_project in self.projects.values_mut() {
585 zeta_project.events.clear();
586 }
587 }
588
589 pub fn context_for_project(
590 &self,
591 project: &Entity<Project>,
592 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
593 self.projects
594 .get(&project.entity_id())
595 .and_then(|project| {
596 Some(
597 project
598 .context
599 .as_ref()?
600 .iter()
601 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
602 )
603 })
604 .into_iter()
605 .flatten()
606 }
607
608 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
609 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
610 self.user_store.read(cx).edit_prediction_usage()
611 } else {
612 None
613 }
614 }
615
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`'s state (snapshot, edit
    /// subscriptions, license watcher for its worktree).
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
629
    /// Returns the state for `project`, creating it on first use.
    ///
    /// A syntax index is only built in `ContextMode::Syntax`; note the mode is
    /// captured at creation time, so later option changes do not retrofit an
    /// index onto already-initialized projects.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
662
663 fn handle_project_event(
664 &mut self,
665 project: Entity<Project>,
666 event: &project::Event,
667 cx: &mut Context<Self>,
668 ) {
669 // TODO [zeta2] init with recent paths
670 match event {
671 project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
672 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
673 return;
674 };
675 let path = project.read(cx).path_for_entry(*active_entry_id, cx);
676 if let Some(path) = path {
677 if let Some(ix) = zeta_project
678 .recent_paths
679 .iter()
680 .position(|probe| probe == &path)
681 {
682 zeta_project.recent_paths.remove(ix);
683 }
684 zeta_project.recent_paths.push_front(path);
685 }
686 }
687 project::Event::DiagnosticsUpdated { .. } => {
688 if cx.has_flag::<Zeta2FeatureFlag>() {
689 self.refresh_prediction_from_diagnostics(project, cx);
690 }
691 }
692 _ => (),
693 }
694 }
695
    /// Returns the tracking state for `buffer`, registering it (and a license
    /// watcher for its worktree) on first sight.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license-detection watcher per worktree; it is
        // dropped again when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Feed every edit into the project's event log.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the tracking state when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
755
    /// Records a buffer change for later use as prompt context.
    ///
    /// Consecutive edits to the same buffer that land within
    /// `CHANGE_GROUPING_LINE_SPAN` rows of each other are coalesced into the
    /// pending `last_event` rather than producing one event per keystroke.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        // No new edits since the last snapshot we recorded.
        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End of the most recent edit, used for proximity-based coalescing.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change directly continues the pending
            // event's snapshot chain on the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the new edit lands close enough (in rows) to the last one.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Keep the event list bounded; the pending `last_event` finalized
        // below counts towards EVENT_COUNT_MAX.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
818
819 fn current_prediction_for_buffer(
820 &self,
821 buffer: &Entity<Buffer>,
822 project: &Entity<Project>,
823 cx: &App,
824 ) -> Option<BufferEditPrediction<'_>> {
825 let project_state = self.projects.get(&project.entity_id())?;
826
827 let CurrentEditPrediction {
828 requested_by,
829 prediction,
830 ..
831 } = project_state.current_prediction.as_ref()?;
832
833 if prediction.targets_buffer(buffer.read(cx)) {
834 Some(BufferEditPrediction::Local { prediction })
835 } else {
836 let show_jump = match requested_by {
837 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
838 requested_by_buffer_id == &buffer.entity_id()
839 }
840 PredictionRequestedBy::DiagnosticsUpdate => true,
841 };
842
843 if show_jump {
844 Some(BufferEditPrediction::Jump { prediction })
845 } else {
846 None
847 }
848 }
849 }
850
    /// Reports the current prediction as accepted to the server and cancels
    /// any still-pending prediction requests for the project.
    ///
    /// No-op for the Sweep backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes anything still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the accept endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
903
    /// Sends all locally buffered rejections to the server in one request,
    /// then drops the entries that were included (new rejections may have
    /// been appended while the request was in flight).
    ///
    /// No-op for the Sweep backend or when there is nothing to flush.
    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the newest buffered rejection so we know how far to drain
        // after the request succeeds.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): if serialization fails, `body` is `None` and the
        // request is built from an empty body — confirm this fallback is
        // intended rather than aborting the flush.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drop everything up to and including the last rejection that was
            // part of this request; later entries stay buffered.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
951
952 fn reject_current_prediction(
953 &mut self,
954 reason: EditPredictionRejectReason,
955 project: &Entity<Project>,
956 cx: &mut Context<Self>,
957 ) {
958 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
959 project_state.pending_predictions.clear();
960 if let Some(prediction) = project_state.current_prediction.take() {
961 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
962 }
963 };
964 }
965
966 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
967 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
968 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
969 if !current_prediction.was_shown {
970 current_prediction.was_shown = true;
971 self.shown_predictions
972 .push_front(current_prediction.prediction.clone());
973 if self.shown_predictions.len() > 50 {
974 let completion = self.shown_predictions.pop_back().unwrap();
975 self.rated_predictions.remove(&completion.id);
976 }
977 }
978 }
979 }
980 }
981
    /// Buffers a rejection and schedules a flush: after a 15s debounce
    /// normally, or immediately once the buffer reaches half the per-request
    /// limit.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            reason,
            was_shown,
        });

        // Half the limit leaves headroom for rejections that arrive while the
        // flush request is in flight.
        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2;
        let reject_tx = self.reject_predictions_tx.clone();
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(REJECT_REQUEST_DEBOUNCE)
                    .await;
            }
            reject_tx.unbounded_send(()).log_err();
        }));
    }
1008
1009 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1010 self.projects
1011 .get(&project.entity_id())
1012 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1013 }
1014
    /// Queues a prediction refresh anchored at `position` in `buffer`,
    /// throttled per buffer entity.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            // Tag the result with the originating buffer so display logic can
            // decide between inline and jump presentation.
            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1050
    /// Queues a prediction refresh driven by a diagnostics update: locates the
    /// next diagnostic from the active buffer and requests a prediction there.
    ///
    /// Does nothing when a prediction is already current, since
    /// buffer-originated predictions take priority.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Resolve the active entry to a buffer; bail if there is none.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find the diagnostic location to predict an edit for.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // A buffer-originated prediction may have landed while this
                // request was in flight; in that case report this one as
                // rejected in favor of the current prediction.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1131
    /// Minimum spacing between consecutive prediction refreshes triggered by
    /// the same entity (see `queue_prediction_refresh`). Zero under `cfg(test)`
    /// so tests don't have to wait out the throttle.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1136
1137 fn queue_prediction_refresh(
1138 &mut self,
1139 project: Entity<Project>,
1140 throttle_entity: EntityId,
1141 cx: &mut Context<Self>,
1142 do_refresh: impl FnOnce(
1143 WeakEntity<Self>,
1144 &mut AsyncApp,
1145 )
1146 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1147 + 'static,
1148 ) {
1149 let zeta_project = self.get_or_init_zeta_project(&project, cx);
1150 let pending_prediction_id = zeta_project.next_pending_prediction_id;
1151 zeta_project.next_pending_prediction_id += 1;
1152 let last_request = zeta_project.last_prediction_refresh;
1153
1154 let task = cx.spawn(async move |this, cx| {
1155 if let Some((last_entity, last_timestamp)) = last_request
1156 && throttle_entity == last_entity
1157 && let Some(timeout) =
1158 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1159 {
1160 cx.background_executor().timer(timeout).await;
1161 }
1162
1163 // If this task was cancelled before the throttle timeout expired,
1164 // do not perform a request.
1165 let mut is_cancelled = true;
1166 this.update(cx, |this, cx| {
1167 let project_state = this.get_or_init_zeta_project(&project, cx);
1168 if !project_state
1169 .cancelled_predictions
1170 .remove(&pending_prediction_id)
1171 {
1172 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1173 is_cancelled = false;
1174 }
1175 })
1176 .ok();
1177 if is_cancelled {
1178 return None;
1179 }
1180
1181 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1182 let new_prediction_id = new_prediction_result
1183 .as_ref()
1184 .map(|(prediction, _)| prediction.id.clone());
1185
1186 // When a prediction completes, remove it from the pending list, and cancel
1187 // any pending predictions that were enqueued before it.
1188 this.update(cx, |this, cx| {
1189 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1190
1191 let is_cancelled = zeta_project
1192 .cancelled_predictions
1193 .remove(&pending_prediction_id);
1194
1195 let new_current_prediction = if !is_cancelled
1196 && let Some((prediction_result, requested_by)) = new_prediction_result
1197 {
1198 match prediction_result.prediction {
1199 Ok(prediction) => {
1200 let new_prediction = CurrentEditPrediction {
1201 requested_by,
1202 prediction,
1203 was_shown: false,
1204 };
1205
1206 if let Some(current_prediction) =
1207 zeta_project.current_prediction.as_ref()
1208 {
1209 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1210 {
1211 this.reject_current_prediction(
1212 EditPredictionRejectReason::Replaced,
1213 &project,
1214 cx,
1215 );
1216
1217 Some(new_prediction)
1218 } else {
1219 this.reject_prediction(
1220 new_prediction.prediction.id,
1221 EditPredictionRejectReason::CurrentPreferred,
1222 false,
1223 cx,
1224 );
1225 None
1226 }
1227 } else {
1228 Some(new_prediction)
1229 }
1230 }
1231 Err(reject_reason) => {
1232 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1233 None
1234 }
1235 }
1236 } else {
1237 None
1238 };
1239
1240 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1241
1242 if let Some(new_prediction) = new_current_prediction {
1243 zeta_project.current_prediction = Some(new_prediction);
1244 }
1245
1246 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
1247 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1248 if pending_prediction.id == pending_prediction_id {
1249 pending_predictions.remove(ix);
1250 for pending_prediction in pending_predictions.drain(0..ix) {
1251 zeta_project.cancel_pending_prediction(pending_prediction, cx)
1252 }
1253 break;
1254 }
1255 }
1256 this.get_or_init_zeta_project(&project, cx)
1257 .pending_predictions = pending_predictions;
1258 cx.notify();
1259 })
1260 .ok();
1261
1262 new_prediction_id
1263 });
1264
1265 if zeta_project.pending_predictions.len() <= 1 {
1266 zeta_project.pending_predictions.push(PendingPrediction {
1267 id: pending_prediction_id,
1268 task,
1269 });
1270 } else if zeta_project.pending_predictions.len() == 2 {
1271 let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
1272 zeta_project.pending_predictions.push(PendingPrediction {
1273 id: pending_prediction_id,
1274 task,
1275 });
1276 zeta_project.cancel_pending_prediction(pending_prediction, cx);
1277 }
1278 }
1279
1280 pub fn request_prediction(
1281 &mut self,
1282 project: &Entity<Project>,
1283 active_buffer: &Entity<Buffer>,
1284 position: language::Anchor,
1285 trigger: PredictEditsRequestTrigger,
1286 cx: &mut Context<Self>,
1287 ) -> Task<Result<Option<EditPredictionResult>>> {
1288 self.request_prediction_internal(
1289 project.clone(),
1290 active_buffer.clone(),
1291 position,
1292 trigger,
1293 cx.has_flag::<Zeta2FeatureFlag>(),
1294 cx,
1295 )
1296 }
1297
    /// Dispatches a prediction request to the configured backend (Zeta1,
    /// Zeta2, or Sweep). When the request yields no prediction and
    /// `allow_jump` is set, retries once at the nearest diagnostic outside the
    /// already-searched range — passing `allow_jump = false` so the retry
    /// cannot recurse further.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Rows above/below the cursor considered "near" for diagnostics.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // Empty result: optionally jump to a diagnostic elsewhere and try
            // exactly once more. Only jump when there were recent edit events.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1389
    /// Finds a diagnostic location to "jump" to after an empty prediction:
    /// first the closest diagnostic in the active buffer *outside*
    /// `active_buffer_diagnostic_search_range` (that range was already covered
    /// by the previous request), otherwise a diagnostic in whichever other
    /// project buffer's path shares the most leading path components with the
    /// active buffer. Returns `None` when no candidate exists.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by row distance from the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to other buffers in the project that have diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // Lowest severity value first — NOTE(review): this
                            // presumably follows LSP ordering where lower
                            // values are more severe (Error = 1); confirm.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1468
    /// Requests an edit prediction from the Zeta2 model.
    ///
    /// Pipeline: snapshot project/context state on the main thread, then on a
    /// background task assemble context (agentic retrieved-files mode or
    /// syntax-index mode), build the prompt, send it to the LLM endpoint, and
    /// parse the returned diff/XML edits back into buffer edits. Debug info is
    /// streamed to `debug_tx` subscribers when present.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Parent directory of the active file, used for syntax-context import
        // resolution below.
        let file = active_buffer.read(cx).file();
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Snapshot all retrieved context files: (entity, snapshot, path, ranges).
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering by path, then by range count.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // entry (moving it last), or append a fresh entry.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line-based excerpts for the
                        // wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                            trigger,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                            trigger,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // Notify debug subscribers that a request is about to be sent;
                // keep a channel to report the response later.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Dev escape hatch: skip the network request entirely.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                // No text in the response means "no prediction" — still report
                // the request id so it can be tracked/rejected upstream.
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((Some((request_id, None)), usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path from the model output back to a context
                // buffer and its included ranges.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse model output into concrete edits per prompt format.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        Some((
                            inputs,
                            edited_buffer,
                            edited_buffer_snapshot.clone(),
                            edits,
                            received_response_at,
                        )),
                    )),
                    usage,
                ))
            }
        });

        // Foreground: record usage, then wrap the parsed edits into an
        // EditPredictionResult (or an Empty rejection).
        cx.spawn({
            async move |this, cx| {
                let Some((id, prediction)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                let Some((
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = prediction
                else {
                    return Ok(Some(EditPredictionResult {
                        id,
                        prediction: Err(EditPredictionRejectReason::Empty),
                    }));
                };

                // TODO telemetry: duration, etc
                Ok(Some(
                    EditPredictionResult::new(
                        id,
                        &edited_buffer,
                        &edited_buffer_snapshot,
                        edits.into(),
                        buffer_snapshotted_at,
                        received_response_at,
                        inputs,
                        cx,
                    )
                    .await,
                ))
            }
        })
    }
1887
    /// Sends a prepared OpenAI-style request to the edit-prediction endpoint
    /// (the `PREDICT_EDITS_URL` override when set, otherwise the zed.dev LLM
    /// URL) and returns the parsed response plus usage info from headers.
    ///
    /// With the `eval-support` feature enabled, responses are cached keyed by
    /// a hash of (url, serialized request) so evals can replay deterministically.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Cache hit returns immediately (with no usage info, since no request
        // was made); otherwise remember the key to record the response below.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1945
1946 fn handle_api_response<T>(
1947 this: &WeakEntity<Self>,
1948 response: Result<(T, Option<EditPredictionUsage>)>,
1949 cx: &mut gpui::AsyncApp,
1950 ) -> Result<T> {
1951 match response {
1952 Ok((data, usage)) => {
1953 if let Some(usage) = usage {
1954 this.update(cx, |this, cx| {
1955 this.user_store.update(cx, |user_store, cx| {
1956 user_store.update_edit_prediction_usage(usage, cx);
1957 });
1958 })
1959 .ok();
1960 }
1961 Ok(data)
1962 }
1963 Err(err) => {
1964 if err.is::<ZedUpdateRequiredError>() {
1965 cx.update(|cx| {
1966 this.update(cx, |this, _cx| {
1967 this.update_required = true;
1968 })
1969 .ok();
1970
1971 let error_message: SharedString = err.to_string().into();
1972 show_app_notification(
1973 NotificationId::unique::<ZedUpdateRequiredError>(),
1974 cx,
1975 move |cx| {
1976 cx.new(|cx| {
1977 ErrorMessagePrompt::new(error_message.clone(), cx)
1978 .with_link_button("Update Zed", "https://zed.dev/releases")
1979 })
1980 },
1981 );
1982 })
1983 .ok();
1984 }
1985 Err(err)
1986 }
1987 }
1988 }
1989
    /// Sends an authenticated POST to the LLM service, deserializing the JSON
    /// response body into `Res`.
    ///
    /// Cross-cutting behavior:
    /// - Errors with `ZedUpdateRequiredError` if the server's minimum-version
    ///   header exceeds `app_version`.
    /// - Retries exactly once with a refreshed token when the server reports
    ///   the LLM token as expired.
    /// - Extracts `EditPredictionUsage` from response headers on success.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: initial attempt plus one token-refresh retry.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Surface "client too old" as a typed error so the UI can prompt
            // for an update.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2053
    /// Gap since the last edit after which a new edit counts as "resuming
    /// after idle" and triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refreshing context while the user is actively
    /// editing (refresh happens after they pause).
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2056
    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to Zeta2 in agentic context mode.
        if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
            return;
        }

        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // Did enough time pass since the last edit that this one counts as a
        // fresh editing session? Then refresh immediately; otherwise debounce.
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Assigning the debounce task drops (cancels) any previous one, so
        // rapid edits keep pushing the refresh back.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the task so the refresh runs to completion even
                    // after this debounce task finishes.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2108
2109 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2110 // and avoid spawning more than one concurrent task.
2111 pub fn refresh_context(
2112 &mut self,
2113 project: Entity<Project>,
2114 buffer: Entity<language::Buffer>,
2115 cursor_position: language::Anchor,
2116 cx: &mut Context<Self>,
2117 ) -> Task<Result<()>> {
2118 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2119 return Task::ready(anyhow::Ok(()));
2120 };
2121
2122 let ContextMode::Agentic(options) = &self.options().context else {
2123 return Task::ready(anyhow::Ok(()));
2124 };
2125
2126 let snapshot = buffer.read(cx).snapshot();
2127 let cursor_point = cursor_position.to_point(&snapshot);
2128 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
2129 cursor_point,
2130 &snapshot,
2131 &options.excerpt,
2132 None,
2133 ) else {
2134 return Task::ready(Ok(()));
2135 };
2136
2137 let app_version = AppVersion::global(cx);
2138 let client = self.client.clone();
2139 let llm_token = self.llm_token.clone();
2140 let debug_tx = self.debug_tx.clone();
2141 let current_file_path: Arc<Path> = snapshot
2142 .file()
2143 .map(|f| f.full_path(cx).into())
2144 .unwrap_or_else(|| Path::new("untitled").into());
2145
2146 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
2147 predict_edits_v3::PlanContextRetrievalRequest {
2148 excerpt: cursor_excerpt.text(&snapshot).body,
2149 excerpt_path: current_file_path,
2150 excerpt_line_range: cursor_excerpt.line_range,
2151 cursor_file_max_row: Line(snapshot.max_point().row),
2152 events: zeta_project.events(cx),
2153 },
2154 ) {
2155 Ok(prompt) => prompt,
2156 Err(err) => {
2157 return Task::ready(Err(err));
2158 }
2159 };
2160
2161 if let Some(debug_tx) = &debug_tx {
2162 debug_tx
2163 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
2164 ZetaContextRetrievalStartedDebugInfo {
2165 project: project.clone(),
2166 timestamp: Instant::now(),
2167 search_prompt: prompt.clone(),
2168 },
2169 ))
2170 .ok();
2171 }
2172
2173 pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
2174 let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
2175 language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
2176 );
2177
2178 let description = schema
2179 .get("description")
2180 .and_then(|description| description.as_str())
2181 .unwrap()
2182 .to_string();
2183
2184 (schema.into(), description)
2185 });
2186
2187 let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
2188
2189 let request = open_ai::Request {
2190 model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
2191 messages: vec![open_ai::RequestMessage::User {
2192 content: open_ai::MessageContent::Plain(prompt),
2193 }],
2194 stream: false,
2195 max_completion_tokens: None,
2196 stop: Default::default(),
2197 temperature: 0.7,
2198 tool_choice: None,
2199 parallel_tool_calls: None,
2200 tools: vec![open_ai::ToolDefinition::Function {
2201 function: FunctionDefinition {
2202 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
2203 description: Some(tool_description),
2204 parameters: Some(tool_schema),
2205 },
2206 }],
2207 prompt_cache_key: None,
2208 reasoning_effort: None,
2209 };
2210
2211 #[cfg(feature = "eval-support")]
2212 let eval_cache = self.eval_cache.clone();
2213
2214 cx.spawn(async move |this, cx| {
2215 log::trace!("Sending search planning request");
2216 let response = Self::send_raw_llm_request(
2217 request,
2218 client,
2219 llm_token,
2220 app_version,
2221 #[cfg(feature = "eval-support")]
2222 eval_cache.clone(),
2223 #[cfg(feature = "eval-support")]
2224 EvalCacheEntryKind::Context,
2225 )
2226 .await;
2227 let mut response = Self::handle_api_response(&this, response, cx)?;
2228 log::trace!("Got search planning response");
2229
2230 let choice = response
2231 .choices
2232 .pop()
2233 .context("No choices in retrieval response")?;
2234 let open_ai::RequestMessage::Assistant {
2235 content: _,
2236 tool_calls,
2237 } = choice.message
2238 else {
2239 anyhow::bail!("Retrieval response didn't include an assistant message");
2240 };
2241
2242 let mut queries: Vec<SearchToolQuery> = Vec::new();
2243 for tool_call in tool_calls {
2244 let open_ai::ToolCallContent::Function { function } = tool_call.content;
2245 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
2246 log::warn!(
2247 "Context retrieval response tried to call an unknown tool: {}",
2248 function.name
2249 );
2250
2251 continue;
2252 }
2253
2254 let input: SearchToolInput = serde_json::from_str(&function.arguments)
2255 .with_context(|| format!("invalid search json {}", &function.arguments))?;
2256 queries.extend(input.queries);
2257 }
2258
2259 if let Some(debug_tx) = &debug_tx {
2260 debug_tx
2261 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
2262 ZetaSearchQueryDebugInfo {
2263 project: project.clone(),
2264 timestamp: Instant::now(),
2265 search_queries: queries.clone(),
2266 },
2267 ))
2268 .ok();
2269 }
2270
2271 log::trace!("Running retrieval search: {queries:#?}");
2272
2273 let related_excerpts_result = retrieval_search::run_retrieval_searches(
2274 queries,
2275 project.clone(),
2276 #[cfg(feature = "eval-support")]
2277 eval_cache,
2278 cx,
2279 )
2280 .await;
2281
2282 log::trace!("Search queries executed");
2283
2284 if let Some(debug_tx) = &debug_tx {
2285 debug_tx
2286 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
2287 ZetaContextRetrievalDebugInfo {
2288 project: project.clone(),
2289 timestamp: Instant::now(),
2290 },
2291 ))
2292 .ok();
2293 }
2294
2295 this.update(cx, |this, _cx| {
2296 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
2297 return Ok(());
2298 };
2299 zeta_project.refresh_context_task.take();
2300 if let Some(debug_tx) = &this.debug_tx {
2301 debug_tx
2302 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
2303 ZetaContextRetrievalDebugInfo {
2304 project,
2305 timestamp: Instant::now(),
2306 },
2307 ))
2308 .ok();
2309 }
2310 match related_excerpts_result {
2311 Ok(excerpts) => {
2312 zeta_project.context = Some(excerpts);
2313 Ok(())
2314 }
2315 Err(error) => Err(error),
2316 }
2317 })?
2318 })
2319 }
2320
2321 pub fn set_context(
2322 &mut self,
2323 project: Entity<Project>,
2324 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2325 ) {
2326 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2327 zeta_project.context = Some(context);
2328 }
2329 }
2330
2331 fn gather_nearby_diagnostics(
2332 cursor_offset: usize,
2333 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2334 snapshot: &BufferSnapshot,
2335 max_diagnostics_bytes: usize,
2336 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2337 // TODO: Could make this more efficient
2338 let mut diagnostic_groups = Vec::new();
2339 for (language_server_id, diagnostics) in diagnostic_sets {
2340 let mut groups = Vec::new();
2341 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2342 diagnostic_groups.extend(
2343 groups
2344 .into_iter()
2345 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2346 );
2347 }
2348
2349 // sort by proximity to cursor
2350 diagnostic_groups.sort_by_key(|group| {
2351 let range = &group.entries[group.primary_ix].range;
2352 if range.start >= cursor_offset {
2353 range.start - cursor_offset
2354 } else if cursor_offset >= range.end {
2355 cursor_offset - range.end
2356 } else {
2357 (cursor_offset - range.start).min(range.end - cursor_offset)
2358 }
2359 });
2360
2361 let mut results = Vec::new();
2362 let mut diagnostic_groups_truncated = false;
2363 let mut diagnostics_byte_count = 0;
2364 for group in diagnostic_groups {
2365 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2366 diagnostics_byte_count += raw_value.get().len();
2367 if diagnostics_byte_count > max_diagnostics_bytes {
2368 diagnostic_groups_truncated = true;
2369 break;
2370 }
2371 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2372 }
2373
2374 (results, diagnostic_groups_truncated)
2375 }
2376
2377 // TODO: Dedupe with similar code in request_prediction?
2378 pub fn cloud_request_for_zeta_cli(
2379 &mut self,
2380 project: &Entity<Project>,
2381 buffer: &Entity<Buffer>,
2382 position: language::Anchor,
2383 cx: &mut Context<Self>,
2384 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
2385 let project_state = self.projects.get(&project.entity_id());
2386
2387 let index_state = project_state.and_then(|state| {
2388 state
2389 .syntax_index
2390 .as_ref()
2391 .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
2392 });
2393 let options = self.options.clone();
2394 let snapshot = buffer.read(cx).snapshot();
2395 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
2396 return Task::ready(Err(anyhow!("No file path for excerpt")));
2397 };
2398 let worktree_snapshots = project
2399 .read(cx)
2400 .worktrees(cx)
2401 .map(|worktree| worktree.read(cx).snapshot())
2402 .collect::<Vec<_>>();
2403
2404 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
2405 let mut path = f.worktree.read(cx).absolutize(&f.path);
2406 if path.pop() { Some(path) } else { None }
2407 });
2408
2409 cx.background_spawn(async move {
2410 let index_state = if let Some(index_state) = index_state {
2411 Some(index_state.lock_owned().await)
2412 } else {
2413 None
2414 };
2415
2416 let cursor_point = position.to_point(&snapshot);
2417
2418 let debug_info = true;
2419 EditPredictionContext::gather_context(
2420 cursor_point,
2421 &snapshot,
2422 parent_abs_path.as_deref(),
2423 match &options.context {
2424 ContextMode::Agentic(_) => {
2425 // TODO
2426 panic!("Llm mode not supported in zeta cli yet");
2427 }
2428 ContextMode::Syntax(edit_prediction_context_options) => {
2429 edit_prediction_context_options
2430 }
2431 },
2432 index_state.as_deref(),
2433 )
2434 .context("Failed to select excerpt")
2435 .map(|context| {
2436 make_syntax_context_cloud_request(
2437 excerpt_path.into(),
2438 context,
2439 // TODO pass everything
2440 Vec::new(),
2441 false,
2442 Vec::new(),
2443 false,
2444 None,
2445 debug_info,
2446 &worktree_snapshots,
2447 index_state.as_deref(),
2448 Some(options.max_prompt_bytes),
2449 options.prompt_format,
2450 PredictEditsRequestTrigger::Other,
2451 )
2452 })
2453 })
2454 }
2455
2456 pub fn wait_for_initial_indexing(
2457 &mut self,
2458 project: &Entity<Project>,
2459 cx: &mut Context<Self>,
2460 ) -> Task<Result<()>> {
2461 let zeta_project = self.get_or_init_zeta_project(project, cx);
2462 if let Some(syntax_index) = &zeta_project.syntax_index {
2463 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2464 } else {
2465 Task::ready(Ok(()))
2466 }
2467 }
2468
2469 fn is_file_open_source(
2470 &self,
2471 project: &Entity<Project>,
2472 file: &Arc<dyn File>,
2473 cx: &App,
2474 ) -> bool {
2475 if !file.is_local() || file.is_private() {
2476 return false;
2477 }
2478 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2479 return false;
2480 };
2481 zeta_project
2482 .license_detection_watchers
2483 .get(&file.worktree_id(cx))
2484 .as_ref()
2485 .is_some_and(|watcher| watcher.is_project_open_source())
2486 }
2487
2488 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2489 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2490 }
2491
2492 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2493 if !self.data_collection_choice.is_enabled() {
2494 return false;
2495 }
2496 events.iter().all(|event| {
2497 matches!(
2498 event.as_ref(),
2499 Event::BufferChange {
2500 in_open_source_repo: true,
2501 ..
2502 }
2503 )
2504 })
2505 }
2506
2507 fn load_data_collection_choice() -> DataCollectionChoice {
2508 let choice = KEY_VALUE_STORE
2509 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2510 .log_err()
2511 .flatten();
2512
2513 match choice.as_deref() {
2514 Some("true") => DataCollectionChoice::Enabled,
2515 Some("false") => DataCollectionChoice::Disabled,
2516 Some(_) => {
2517 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2518 DataCollectionChoice::NotAnswered
2519 }
2520 None => DataCollectionChoice::NotAnswered,
2521 }
2522 }
2523
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2527
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2531
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2535
    /// Records the user's rating of `prediction`, reports it via telemetry
    /// (including a unified diff of the predicted edits and the user's
    /// free-form feedback), and notifies observers.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        // Remember this prediction as rated so the UI can reflect it.
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush immediately rather than waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2554}
2555
2556pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2557 let choice = res.choices.pop()?;
2558 let output_text = match choice.message {
2559 open_ai::RequestMessage::Assistant {
2560 content: Some(open_ai::MessageContent::Plain(content)),
2561 ..
2562 } => content,
2563 open_ai::RequestMessage::Assistant {
2564 content: Some(open_ai::MessageContent::Multipart(mut content)),
2565 ..
2566 } => {
2567 if content.is_empty() {
2568 log::error!("No output from Baseten completion response");
2569 return None;
2570 }
2571
2572 match content.remove(0) {
2573 open_ai::MessagePart::Text { text } => text,
2574 open_ai::MessagePart::Image { .. } => {
2575 log::error!("Expected text, got an image");
2576 return None;
2577 }
2578 }
2579 }
2580 _ => {
2581 log::error!("Invalid response message: {:?}", choice.message);
2582 return None;
2583 }
2584 };
2585 Some(output_text)
2586}
2587
/// Error surfaced when the prediction service requires a newer Zed version;
/// the derived `Display` impl produces the user-facing message.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version the server will accept.
    minimum_version: Version,
}
2595
2596fn make_syntax_context_cloud_request(
2597 excerpt_path: Arc<Path>,
2598 context: EditPredictionContext,
2599 events: Vec<Arc<predict_edits_v3::Event>>,
2600 can_collect_data: bool,
2601 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2602 diagnostic_groups_truncated: bool,
2603 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2604 debug_info: bool,
2605 worktrees: &Vec<worktree::Snapshot>,
2606 index_state: Option<&SyntaxIndexState>,
2607 prompt_max_bytes: Option<usize>,
2608 prompt_format: PromptFormat,
2609 trigger: PredictEditsRequestTrigger,
2610) -> predict_edits_v3::PredictEditsRequest {
2611 let mut signatures = Vec::new();
2612 let mut declaration_to_signature_index = HashMap::default();
2613 let mut referenced_declarations = Vec::new();
2614
2615 for snippet in context.declarations {
2616 let project_entry_id = snippet.declaration.project_entry_id();
2617 let Some(path) = worktrees.iter().find_map(|worktree| {
2618 worktree.entry_for_id(project_entry_id).map(|entry| {
2619 let mut full_path = RelPathBuf::new();
2620 full_path.push(worktree.root_name());
2621 full_path.push(&entry.path);
2622 full_path
2623 })
2624 }) else {
2625 continue;
2626 };
2627
2628 let parent_index = index_state.and_then(|index_state| {
2629 snippet.declaration.parent().and_then(|parent| {
2630 add_signature(
2631 parent,
2632 &mut declaration_to_signature_index,
2633 &mut signatures,
2634 index_state,
2635 )
2636 })
2637 });
2638
2639 let (text, text_is_truncated) = snippet.declaration.item_text();
2640 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2641 path: path.as_std_path().into(),
2642 text: text.into(),
2643 range: snippet.declaration.item_line_range(),
2644 text_is_truncated,
2645 signature_range: snippet.declaration.signature_range_in_item_text(),
2646 parent_index,
2647 signature_score: snippet.score(DeclarationStyle::Signature),
2648 declaration_score: snippet.score(DeclarationStyle::Declaration),
2649 score_components: snippet.components,
2650 });
2651 }
2652
2653 let excerpt_parent = index_state.and_then(|index_state| {
2654 context
2655 .excerpt
2656 .parent_declarations
2657 .last()
2658 .and_then(|(parent, _)| {
2659 add_signature(
2660 *parent,
2661 &mut declaration_to_signature_index,
2662 &mut signatures,
2663 index_state,
2664 )
2665 })
2666 });
2667
2668 predict_edits_v3::PredictEditsRequest {
2669 excerpt_path,
2670 excerpt: context.excerpt_text.body,
2671 excerpt_line_range: context.excerpt.line_range,
2672 excerpt_range: context.excerpt.range,
2673 cursor_point: predict_edits_v3::Point {
2674 line: predict_edits_v3::Line(context.cursor_point.row),
2675 column: context.cursor_point.column,
2676 },
2677 referenced_declarations,
2678 included_files: vec![],
2679 signatures,
2680 excerpt_parent,
2681 events,
2682 can_collect_data,
2683 diagnostic_groups,
2684 diagnostic_groups_truncated,
2685 git_info,
2686 debug_info,
2687 prompt_max_bytes,
2688 prompt_format,
2689 trigger,
2690 }
2691}
2692
2693fn add_signature(
2694 declaration_id: DeclarationId,
2695 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2696 signatures: &mut Vec<Signature>,
2697 index: &SyntaxIndexState,
2698) -> Option<usize> {
2699 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2700 return Some(*signature_index);
2701 }
2702 let Some(parent_declaration) = index.declaration(declaration_id) else {
2703 log::error!("bug: missing parent declaration");
2704 return None;
2705 };
2706 let parent_index = parent_declaration.parent().and_then(|parent| {
2707 add_signature(parent, declaration_to_signature_index, signatures, index)
2708 });
2709 let (text, text_is_truncated) = parent_declaration.signature_text();
2710 let signature_index = signatures.len();
2711 signatures.push(Signature {
2712 text: text.into(),
2713 text_is_truncated,
2714 parent_index,
2715 range: parent_declaration.signature_line_range(),
2716 });
2717 declaration_to_signature_index.insert(declaration_id, signature_index);
2718 Some(signature_index)
2719}
2720
/// Cache key for eval-support lookups: the kind of request plus a content hash.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2723
/// Which stage of the prediction pipeline a cached eval entry belongs to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2731
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase label for this entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{label}")
    }
}
2742
/// Read/write interface for caching LLM request/response pairs during evals.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    // Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    // Stores `value` for `key`; `input` is the request text the key was derived from.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2748
/// The user's choice about contributing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Only an explicit `Enabled` counts as opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Flips the choice; an unanswered state toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Disabled | Self::NotAnswered => Self::Enabled,
            Self::Enabled => Self::Disabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value { Self::Enabled } else { Self::Disabled }
    }
}
2789
2790struct ZedPredictUpsell;
2791
2792impl Dismissable for ZedPredictUpsell {
2793 const KEY: &'static str = "dismissed-edit-predict-upsell";
2794
2795 fn dismissed() -> bool {
2796 // To make this backwards compatible with older versions of Zed, we
2797 // check if the user has seen the previous Edit Prediction Onboarding
2798 // before, by checking the data collection choice which was written to
2799 // the database once the user clicked on "Accept and Enable"
2800 if KEY_VALUE_STORE
2801 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2802 .log_err()
2803 .is_some_and(|s| s.is_some())
2804 {
2805 return true;
2806 }
2807
2808 KEY_VALUE_STORE
2809 .read_kvp(Self::KEY)
2810 .log_err()
2811 .is_some_and(|s| s.is_some())
2812 }
2813}
2814
/// Whether the edit-prediction upsell modal should be presented (i.e. it has
/// not been dismissed yet).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2818
/// Registers edit-prediction workspace actions and sets up their
/// command-palette feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind a feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding sets the configured provider to `None` in the
        // settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2854
/// Hides/shows zeta-related command-palette actions based on the AI-disable
/// setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Every zeta action, used to hide everything at once when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Start with everything hidden; the observers below reveal actions as
    // settings/flags permit.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // React to settings changes (including toggling "disable AI").
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // React to the rate-completions feature flag changing at runtime.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2906
2907#[cfg(test)]
2908mod tests {
2909 use std::{path::Path, sync::Arc};
2910
2911 use client::UserStore;
2912 use clock::FakeSystemClock;
2913 use cloud_llm_client::{
2914 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2915 };
2916 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2917 use futures::{
2918 AsyncReadExt, StreamExt,
2919 channel::{mpsc, oneshot},
2920 };
2921 use gpui::{
2922 Entity, TestAppContext,
2923 http_client::{FakeHttpClient, Response},
2924 prelude::*,
2925 };
2926 use indoc::indoc;
2927 use language::OffsetRangeExt as _;
2928 use open_ai::Usage;
2929 use pretty_assertions::{assert_eq, assert_matches};
2930 use project::{FakeFs, Project};
2931 use serde_json::json;
2932 use settings::SettingsStore;
2933 use util::path;
2934 use uuid::Uuid;
2935
2936 use crate::{BufferEditPrediction, Zeta};
2937
    // End-to-end test of prediction state tracking: a same-file prediction,
    // a context refresh via the search tool, rejection, and a cross-file
    // prediction that surfaces as a `Jump`.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // An edit within the same buffer is reported as a `Local` prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // Respond with a tool call planning a search over `root/2.txt`.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction targets another file, so
        // it is surfaced as a `Jump` to `root/2.txt`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // The same prediction is `Local` when viewed from the target buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3084
    // Happy-path test: a single prediction request whose diff response is
    // parsed into one edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff's only change should become a single insertion at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3148
    // Verifies that recent buffer edits are included as an event diff in the
    // prediction request prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so zeta records subsequent edits as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The prompt should contain the edit above as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3222
    // A response whose diff produces no actual change should yield no current
    // prediction and be auto-rejected with `Empty`.
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Removes and re-adds the identical line: a no-op edit.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How
             Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3286
    // If the user types the predicted text before the response arrives,
    // interpolation leaves nothing to apply: the prediction is dropped and
    // auto-rejected with `InterpolatedEmpty`.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Apply the predicted edit manually before the response lands.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3345
/// Unified diff shared by several tests below: rewrites line 2 of
/// `root/foo.md` from "How" to "How are you?".
const SIMPLE_DIFF: &str = indoc! { r"
    --- a/root/foo.md
    +++ b/root/foo.md
    @@ ... @@
    Hello!
    -How
    +How are you?
    Bye
"};
3355
/// Verifies that when a second prediction completes after a first one is
/// already current, the second replaces the first, and the first is reported
/// as rejected with reason `Replaced`.
#[gpui::test]
async fn test_replace_current(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First prediction becomes current.
    zeta.read_with(cx, |zeta, cx| {
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // second replaces first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Replaced,
            was_shown: false
        }]
    );
}
3435
/// Verifies that when a newer prediction is judged worse than the current one
/// (here: a strict prefix of the current prediction's edit), the current
/// prediction is kept and the newer one is reported as rejected with reason
/// `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    // First prediction becomes current.
    zeta.read_with(cx, |zeta, cx| {
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction
    let second_response = model_response(indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are
        Bye
    "});
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // first is preferred over second
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
3524
/// Verifies that when two requests are in flight and the later one completes
/// first, the earlier request is cancelled: its late response does not replace
/// the current prediction, and it is reported as rejected with reason
/// `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        // start two refresh tasks
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();
    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is second
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the (cancelled) first request respond.
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
3610
/// Verifies cancellation bookkeeping across three overlapping requests:
/// starting a third request while two are pending cancels the second; the
/// first's response becomes current, the cancelled second's response is
/// ignored, and the third's response replaces the first. The rejection report
/// batches both the `Canceled` second and the `Replaced` first.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        // start two refresh tasks
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    let (_, respond_first) = requests.predict.next().await.unwrap();
    let (_, respond_second) = requests.predict.next().await.unwrap();

    zeta.update(cx, |zeta, cx| {
        // start a third request
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            zeta.get_or_init_zeta_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (_, respond_third) = requests.predict.next().await.unwrap();

    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // Respond to the cancelled second request; its prediction must be ignored.
    let cancelled_response = model_response(SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // third completes and replaces first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
3738
3739 // Skipped until we start including diagnostics in prompt
3740 // #[gpui::test]
3741 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3742 // let (zeta, mut req_rx) = init_test(cx);
3743 // let fs = FakeFs::new(cx.executor());
3744 // fs.insert_tree(
3745 // "/root",
3746 // json!({
3747 // "foo.md": "Hello!\nBye"
3748 // }),
3749 // )
3750 // .await;
3751 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3752
3753 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3754 // let diagnostic = lsp::Diagnostic {
3755 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3756 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3757 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3758 // ..Default::default()
3759 // };
3760
3761 // project.update(cx, |project, cx| {
3762 // project.lsp_store().update(cx, |lsp_store, cx| {
3763 // // Create some diagnostics
3764 // lsp_store
3765 // .update_diagnostics(
3766 // LanguageServerId(0),
3767 // lsp::PublishDiagnosticsParams {
3768 // uri: path_to_buffer_uri.clone(),
3769 // diagnostics: vec![diagnostic],
3770 // version: None,
3771 // },
3772 // None,
3773 // language::DiagnosticSourceKind::Pushed,
3774 // &[],
3775 // cx,
3776 // )
3777 // .unwrap();
3778 // });
3779 // });
3780
3781 // let buffer = project
3782 // .update(cx, |project, cx| {
3783 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3784 // project.open_buffer(path, cx)
3785 // })
3786 // .await
3787 // .unwrap();
3788
3789 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3790 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3791
3792 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3793 // zeta.request_prediction(&project, &buffer, position, cx)
3794 // });
3795
3796 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3797
3798 // assert_eq!(request.diagnostic_groups.len(), 1);
3799 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3800 // .unwrap();
3801 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3802 // assert_eq!(
3803 // value,
3804 // json!({
3805 // "entries": [{
3806 // "range": {
3807 // "start": 8,
3808 // "end": 10
3809 // },
3810 // "diagnostic": {
3811 // "source": null,
3812 // "code": null,
3813 // "code_description": null,
3814 // "severity": 1,
3815 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3816 // "markdown": null,
3817 // "group_id": 0,
3818 // "is_primary": true,
3819 // "is_disk_based": false,
3820 // "is_unnecessary": false,
3821 // "source_kind": "Pushed",
3822 // "data": null,
3823 // "underline": true
3824 // }
3825 // }],
3826 // "primary_ix": 0
3827 // })
3828 // );
3829 // }
3830
3831 fn model_response(text: &str) -> open_ai::Response {
3832 open_ai::Response {
3833 id: Uuid::new_v4().to_string(),
3834 object: "response".into(),
3835 created: 0,
3836 model: "model".into(),
3837 choices: vec![open_ai::Choice {
3838 index: 0,
3839 message: open_ai::RequestMessage::Assistant {
3840 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3841 tool_calls: vec![],
3842 },
3843 finish_reason: None,
3844 }],
3845 usage: Usage {
3846 prompt_tokens: 0,
3847 completion_tokens: 0,
3848 total_tokens: 0,
3849 },
3850 }
3851 }
3852
3853 fn prompt_from_request(request: &open_ai::Request) -> &str {
3854 assert_eq!(request.messages.len(), 1);
3855 let open_ai::RequestMessage::User {
3856 content: open_ai::MessageContent::Plain(content),
3857 ..
3858 } = &request.messages[0]
3859 else {
3860 panic!(
3861 "Request does not have single user message of type Plain. {:#?}",
3862 request
3863 );
3864 };
3865 content
3866 }
3867
/// Receiving ends of the channels through which the fake HTTP client built in
/// `init_test` forwards intercepted requests, letting tests inspect each
/// request and decide when and how the fake server responds.
struct RequestChannels {
    // Prediction requests sent to `/predict_edits/raw`, each paired with a
    // oneshot sender for the response the fake server should return.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Rejection reports sent to `/predict_edits/reject`, each paired with a
    // oneshot sender used to acknowledge the report.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
3872
/// Sets up a global `Zeta` entity backed by a fake HTTP client and returns it
/// together with `RequestChannels` for intercepting its server traffic.
///
/// The fake client answers `/client/llm_tokens` immediately with a dummy
/// token, and forwards the parsed bodies of `/predict_edits/raw` and
/// `/predict_edits/reject` over unbounded channels, awaiting the test's reply
/// before producing the HTTP response. Any other path panics.
fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                // Clone the senders so each request's async block owns its own copies.
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token endpoint: reply immediately with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Prediction endpoint: hand the parsed body to the test
                        // and block until the test sends back a response.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        // Rejection endpoint: same handshake as above, with a
                        // unit acknowledgement instead of a model response.
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        // Pretend to be an authenticated user so requests are permitted.
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = Zeta::global(&client, &user_store, cx);

        (
            zeta,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
3939}