1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
9 ZED_VERSION_HEADER_NAME,
10};
11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
13use collections::{HashMap, HashSet};
14use command_palette_hooks::CommandPaletteFilter;
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{
17 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
18 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
19 SyntaxIndex, SyntaxIndexState,
20};
21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
22use futures::channel::{mpsc, oneshot};
23use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _};
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::{
30 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
31};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, RefreshLlmTokenListener};
34use open_ai::FunctionDefinition;
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
40use std::any::{Any as _, TypeId};
41use std::collections::{VecDeque, hash_map};
42use telemetry_events::EditPredictionRating;
43use workspace::Workspace;
44
45use std::ops::Range;
46use std::path::Path;
47use std::rc::Rc;
48use std::str::FromStr as _;
49use std::sync::{Arc, LazyLock};
50use std::time::{Duration, Instant};
51use std::{env, mem};
52use thiserror::Error;
53use util::rel_path::RelPathBuf;
54use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
56
57pub mod assemble_excerpts;
58mod license_detection;
59mod onboarding_modal;
60mod prediction;
61mod provider;
62mod rate_prediction_modal;
63pub mod retrieval_search;
64pub mod sweep_ai;
65pub mod udiff;
66mod xml_edits;
67pub mod zeta1;
68
69#[cfg(test)]
70mod zeta_tests;
71
72use crate::assemble_excerpts::assemble_excerpts;
73use crate::license_detection::LicenseDetectionWatcher;
74use crate::onboarding_modal::ZedPredictModal;
75pub use crate::prediction::EditPrediction;
76pub use crate::prediction::EditPredictionId;
77pub use crate::prediction::EditPredictionInputs;
78use crate::prediction::EditPredictionResult;
79use crate::rate_prediction_modal::{
80 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
81 ThumbsUpActivePrediction,
82};
83pub use crate::sweep_ai::SweepAi;
84use crate::zeta1::request_prediction_with_zeta1;
85pub use provider::ZetaEditPredictionProvider;
86
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history for all projects.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits whose end points are within this many rows of each other
/// are coalesced into a single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key under which the user's data-collection choice is persisted
/// (presumably via the KVP store; the load path is defined elsewhere in this file).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
103
104pub struct SweepFeatureFlag;
105
106impl FeatureFlag for SweepFeatureFlag {
107 const NAME: &str = "sweep-ai";
108}
/// Default sizing for the excerpt surrounding the cursor that is sent to the model.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for roughly half the excerpt bytes before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Context retrieval defaults to the agentic (search-based) mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for the syntax-index-based context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default options applied to a fresh [`Zeta`] instance.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
140
// Environment overrides for local development: route requests to a local Ollama
// server and/or select alternate model ids.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model used for the context-retrieval (search) step.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model used for the edit-prediction step.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Explicit prediction endpoint; `None` means use the Zed LLM service URL.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
168
/// Feature flag gating the zeta2 prediction pipeline; always on for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
178
/// Holds the app-wide singleton [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
183
/// App-wide coordinator for edit predictions: tracks per-project state, talks to
/// the prediction backend, and records shown/rejected/rated predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps `llm_token` fresh; wired up in `Zeta::new`.
    _llm_token_subscription: Subscription,
    // Per-project state keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports this Zed version is too
    // old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME import) — setter not in view.
    update_required: bool,
    // When set, debug events are streamed to the receiver from `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    pub sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections are batched here and flushed via `reject_predictions_tx`.
    rejected_predictions: Vec<EditPredictionRejection>,
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    // Most-recently-shown predictions (front = newest), capped at 50.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
204
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}

/// Tunable knobs for prediction requests; see [`DEFAULT_OPTIONS`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}
222
/// How context for a prediction is gathered: via model-driven search (agentic)
/// or via the syntax index.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
233
234impl ContextMode {
235 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
236 match self {
237 ContextMode::Agentic(options) => &options.excerpt,
238 ContextMode::Syntax(options) => &options.excerpt,
239 }
240 }
241}
242
/// Debug events streamed to subscribers of [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // Ok = locally assembled prompt; Err presumably carries the failure message.
    pub local_prompt: Result<String, String>,
    // Delivers the model response (or error string) plus the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
283
/// Per-project prediction state owned by [`Zeta`].
struct ZetaProject {
    // Present only in `ContextMode::Syntax`; see `get_or_init_zeta_project`.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized change events, oldest first, bounded by EVENT_COUNT_MAX.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // The still-coalescing most recent change; finalized lazily.
    last_event: Option<LastEvent>,
    // Most recently active project paths, front = most recent.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // At most 2 in-flight refreshes.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // (entity, time) of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions whose results should be discarded.
    cancelled_predictions: HashSet<usize>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    // One license watcher per worktree, used to mark events as open-source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
302
impl ZetaProject {
    /// All finalized events plus the still-pending `last_event` (if it
    /// finalizes to an actual change), oldest first.
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }

    /// Marks a pending prediction as cancelled, and reports its result as
    /// rejected (reason `Canceled`) once the in-flight task resolves.
    fn cancel_pending_prediction(
        &mut self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Zeta>,
    ) {
        self.cancelled_predictions.insert(pending_prediction.id);

        cx.spawn(async move |this, cx| {
            // The task yields None when no prediction was produced; nothing to reject.
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.reject_prediction(
                    prediction_id,
                    EditPredictionRejectReason::Canceled,
                    false,
                    cx,
                );
            })
            .ok();
        })
        .detach()
    }
}
341
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // True once the prediction has actually been displayed in the UI.
    pub was_shown: bool,
}
348
impl CurrentEditPrediction {
    /// Whether `self` (a new prediction) should replace `old_prediction`.
    /// Returns false if the new prediction no longer interpolates against the
    /// buffer's current state; otherwise prefers replacement except in the
    /// single-edit same-buffer case below.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // Old prediction is stale against the current buffer; replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace when the new edit strictly extends the old one in place.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
387
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    Buffer(EntityId),
}
393
394impl PredictionRequestedBy {
395 pub fn buffer_id(&self) -> Option<EntityId> {
396 match self {
397 PredictionRequestedBy::DiagnosticsUpdate => None,
398 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
399 }
400 }
401}
402
/// An in-flight prediction refresh; `task` yields the resulting prediction id,
/// or None if no prediction was produced.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    // The prediction targets another buffer; this buffer offers a jump to it.
    Jump { prediction: &'a EditPrediction },
}
415
// Test convenience: lets assertions access the underlying prediction directly
// regardless of the Local/Jump distinction.
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    fn deref(&self) -> &Self::Target {
        match self {
            BufferEditPrediction::Local { prediction } => prediction,
            BufferEditPrediction::Jump { prediction } => prediction,
        }
    }
}
427
/// Per-buffer tracking state: the last snapshot we diffed against plus the
/// edit/release subscriptions keeping it up to date.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The most recent buffer change, kept un-finalized so further nearby edits
/// can be coalesced into it (see `report_changes_for_buffer`).
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    end_edit_anchor: Option<Anchor>,
}
438
impl LastEvent {
    /// Converts the accumulated change into a `BufferChange` event.
    /// Returns None when nothing observable changed (same path, empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Open-source only if both the old and new file's worktrees are
        // recognized as open source by their license watchers.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
475
476fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
477 if let Some(file) = snapshot.file() {
478 file.full_path(cx).into()
479 } else {
480 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
481 }
482}
483
484impl Zeta {
485 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
486 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
487 }
488
489 pub fn global(
490 client: &Arc<Client>,
491 user_store: &Entity<UserStore>,
492 cx: &mut App,
493 ) -> Entity<Self> {
494 cx.try_global::<ZetaGlobal>()
495 .map(|global| global.0.clone())
496 .unwrap_or_else(|| {
497 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
498 cx.set_global(ZetaGlobal(zeta.clone()));
499 zeta
500 })
501 }
502
    /// Creates a new `Zeta`, wiring up the LLM-token refresh subscription and
    /// the background loop that flushes batched prediction rejections.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Rejections are sent in batches; this channel wakes the flush loop.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
550
    /// Selects which backend model serves edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
554
555 pub fn has_sweep_api_token(&self) -> bool {
556 self.sweep_ai
557 .api_token
558 .clone()
559 .now_or_never()
560 .flatten()
561 .is_some()
562 }
563
    /// Installs a cache used by evals to avoid repeated backend calls.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
568
569 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
570 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
571 self.debug_tx = Some(debug_watch_tx);
572 debug_watch_rx
573 }
574
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    // NOTE(review): already-registered projects keep the syntax index they were
    // created with; a context-mode change only affects newly registered projects
    // (see `get_or_init_zeta_project`) — confirm this is intended.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
582
583 pub fn clear_history(&mut self) {
584 for zeta_project in self.projects.values_mut() {
585 zeta_project.events.clear();
586 }
587 }
588
589 pub fn context_for_project(
590 &self,
591 project: &Entity<Project>,
592 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
593 self.projects
594 .get(&project.entity_id())
595 .and_then(|project| {
596 Some(
597 project
598 .context
599 .as_ref()?
600 .iter()
601 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
602 )
603 })
604 .into_iter()
605 .flatten()
606 }
607
608 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
609 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
610 self.user_store.read(cx).edit_prediction_usage()
611 } else {
612 None
613 }
614 }
615
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project` (snapshots, edit events,
    /// license detection).
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
629
    /// Returns the state for `project`, creating it on first call. The syntax
    /// index is only built when the current options use `ContextMode::Syntax`.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
662
    /// Reacts to project events: maintains the recent-paths MRU list and
    /// refreshes predictions when diagnostics change (zeta2 flag only).
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front of the MRU list, deduplicating.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
695
    /// Registers `buffer` with `zeta_project`: ensures a license watcher for
    /// its worktree, and subscribes to edits and release so changes are
    /// reported and state cleaned up. Idempotent per buffer.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license-detection watcher per worktree; it is
        // dropped again when the worktree entity is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Report edits so they become change events.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop tracking state when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
755
    /// Records a buffer edit as a change event. Edits in the same buffer whose
    /// end points are within `CHANGE_GROUPING_LINE_SPAN` rows of the previous
    /// edit are coalesced into the pending `last_event` instead of starting a
    /// new one; otherwise the pending event is finalized and a new one begins.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // No version change means nothing to record.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the last edit, used for proximity grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only if this edit continues directly from the pending
            // event's snapshot of the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and the edits are close together (within the grouping span).
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Bound the history: make room for the event we are about to push.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
818
    /// The project's current prediction as seen from `buffer`: `Local` when the
    /// prediction edits this buffer, `Jump` when it targets another buffer but
    /// should still be surfaced here (same requesting buffer, or triggered by
    /// a diagnostics update), and None otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
850
    /// Accepts the project's current prediction: cancels any pending refreshes
    /// and reports the acceptance to the backend. No-op for the Sweep backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // An accepted prediction supersedes anything still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the endpoint for development.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
903
    /// Sends all batched rejections to the backend, then drops the entries
    /// that were included in the request (new rejections may have arrived
    /// while the request was in flight). No-op for the Sweep backend or when
    /// the batch is empty.
    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the last batched rejection so we know how much to drain
        // after the request succeeds.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drop everything up to and including the last rejection we sent.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
951
952 fn reject_current_prediction(
953 &mut self,
954 reason: EditPredictionRejectReason,
955 project: &Entity<Project>,
956 cx: &mut Context<Self>,
957 ) {
958 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
959 project_state.pending_predictions.clear();
960 if let Some(prediction) = project_state.current_prediction.take() {
961 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
962 }
963 };
964 }
965
966 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
967 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
968 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
969 if !current_prediction.was_shown {
970 current_prediction.was_shown = true;
971 self.shown_predictions
972 .push_front(current_prediction.prediction.clone());
973 if self.shown_predictions.len() > 50 {
974 let completion = self.shown_predictions.pop_back().unwrap();
975 self.rated_predictions.remove(&completion.id);
976 }
977 }
978 }
979 }
980 }
981
    /// Queues a rejection for batched reporting. The flush is debounced by 15
    /// seconds, but fires immediately once the batch reaches half the per-request
    /// limit.
    fn reject_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        reason: EditPredictionRejectReason,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            reason,
            was_shown,
        });

        // Half the limit leaves headroom for rejections arriving mid-flush.
        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2;
        let reject_tx = self.reject_predictions_tx.clone();
        // Replacing the task resets the debounce timer on every new rejection.
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(REJECT_REQUEST_DEBOUNCE)
                    .await;
            }
            reject_tx.unbounded_send(()).log_err();
        }));
    }
1008
1009 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1010 self.projects
1011 .get(&project.entity_id())
1012 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1013 }
1014
    /// Queues a prediction refresh triggered by editing `buffer` at `position`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Throttled per buffer entity.
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1050
    /// Queues a prediction refresh triggered by a diagnostics update: finds the
    /// next diagnostic location starting from the active buffer and requests a
    /// prediction there. Buffer-triggered predictions take precedence — if one
    /// exists now (or appears before this request resolves), the diagnostics
    /// prediction is dropped.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        // Throttled per project entity (no specific buffer triggered this).
        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                this.update(cx, |this, cx| {
                    Some((
                        // A buffer-triggered prediction may have landed while we
                        // were waiting; in that case report ours as rejected.
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1131
    /// Minimum spacing between prediction refreshes triggered by the same
    /// entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// In tests, throttling is disabled so refreshes run immediately.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1136
1137 fn queue_prediction_refresh(
1138 &mut self,
1139 project: Entity<Project>,
1140 throttle_entity: EntityId,
1141 cx: &mut Context<Self>,
1142 do_refresh: impl FnOnce(
1143 WeakEntity<Self>,
1144 &mut AsyncApp,
1145 )
1146 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1147 + 'static,
1148 ) {
1149 let zeta_project = self.get_or_init_zeta_project(&project, cx);
1150 let pending_prediction_id = zeta_project.next_pending_prediction_id;
1151 zeta_project.next_pending_prediction_id += 1;
1152 let last_request = zeta_project.last_prediction_refresh;
1153
1154 let task = cx.spawn(async move |this, cx| {
1155 if let Some((last_entity, last_timestamp)) = last_request
1156 && throttle_entity == last_entity
1157 && let Some(timeout) =
1158 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1159 {
1160 cx.background_executor().timer(timeout).await;
1161 }
1162
1163 // If this task was cancelled before the throttle timeout expired,
1164 // do not perform a request.
1165 let mut is_cancelled = true;
1166 this.update(cx, |this, cx| {
1167 let project_state = this.get_or_init_zeta_project(&project, cx);
1168 if !project_state
1169 .cancelled_predictions
1170 .remove(&pending_prediction_id)
1171 {
1172 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1173 is_cancelled = false;
1174 }
1175 })
1176 .ok();
1177 if is_cancelled {
1178 return None;
1179 }
1180
1181 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1182 let new_prediction_id = new_prediction_result
1183 .as_ref()
1184 .map(|(prediction, _)| prediction.id.clone());
1185
1186 // When a prediction completes, remove it from the pending list, and cancel
1187 // any pending predictions that were enqueued before it.
1188 this.update(cx, |this, cx| {
1189 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1190
1191 let is_cancelled = zeta_project
1192 .cancelled_predictions
1193 .remove(&pending_prediction_id);
1194
1195 let new_current_prediction = if !is_cancelled
1196 && let Some((prediction_result, requested_by)) = new_prediction_result
1197 {
1198 match prediction_result.prediction {
1199 Ok(prediction) => {
1200 let new_prediction = CurrentEditPrediction {
1201 requested_by,
1202 prediction,
1203 was_shown: false,
1204 };
1205
1206 if let Some(current_prediction) =
1207 zeta_project.current_prediction.as_ref()
1208 {
1209 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1210 {
1211 this.reject_current_prediction(
1212 EditPredictionRejectReason::Replaced,
1213 &project,
1214 cx,
1215 );
1216
1217 Some(new_prediction)
1218 } else {
1219 this.reject_prediction(
1220 new_prediction.prediction.id,
1221 EditPredictionRejectReason::CurrentPreferred,
1222 false,
1223 cx,
1224 );
1225 None
1226 }
1227 } else {
1228 Some(new_prediction)
1229 }
1230 }
1231 Err(reject_reason) => {
1232 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1233 None
1234 }
1235 }
1236 } else {
1237 None
1238 };
1239
1240 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1241
1242 if let Some(new_prediction) = new_current_prediction {
1243 zeta_project.current_prediction = Some(new_prediction);
1244 }
1245
1246 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
1247 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1248 if pending_prediction.id == pending_prediction_id {
1249 pending_predictions.remove(ix);
1250 for pending_prediction in pending_predictions.drain(0..ix) {
1251 zeta_project.cancel_pending_prediction(pending_prediction, cx)
1252 }
1253 break;
1254 }
1255 }
1256 this.get_or_init_zeta_project(&project, cx)
1257 .pending_predictions = pending_predictions;
1258 cx.notify();
1259 })
1260 .ok();
1261
1262 new_prediction_id
1263 });
1264
1265 if zeta_project.pending_predictions.len() <= 1 {
1266 zeta_project.pending_predictions.push(PendingPrediction {
1267 id: pending_prediction_id,
1268 task,
1269 });
1270 } else if zeta_project.pending_predictions.len() == 2 {
1271 let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
1272 zeta_project.pending_predictions.push(PendingPrediction {
1273 id: pending_prediction_id,
1274 task,
1275 });
1276 zeta_project.cancel_pending_prediction(pending_prediction, cx);
1277 }
1278 }
1279
1280 pub fn request_prediction(
1281 &mut self,
1282 project: &Entity<Project>,
1283 active_buffer: &Entity<Buffer>,
1284 position: language::Anchor,
1285 trigger: PredictEditsRequestTrigger,
1286 cx: &mut Context<Self>,
1287 ) -> Task<Result<Option<EditPredictionResult>>> {
1288 self.request_prediction_internal(
1289 project.clone(),
1290 active_buffer.clone(),
1291 position,
1292 trigger,
1293 cx.has_flag::<Zeta2FeatureFlag>(),
1294 cx,
1295 )
1296 }
1297
    /// Shared implementation behind the prediction entry points.
    ///
    /// Dispatches to the configured model backend. When the backend yields no
    /// prediction and `allow_jump` is set, retries once at the nearest
    /// diagnostic outside the already-searched range (the retry passes
    /// `allow_jump = false`, so at most one jump happens per request).
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor count as already
        // covered by the primary request.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        // Dispatch to the configured backend.
        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there are recent edit events to predict from.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    // Retry once at the diagnostic location, with jumping
                    // disabled to prevent recursion.
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1389
    /// Finds a diagnostic location to "jump" a prediction request to.
    ///
    /// First looks in `active_buffer` for the diagnostic closest (by row) to
    /// `active_buffer_cursor_point` that falls outside
    /// `active_buffer_diagnostic_search_range` (which the caller has already
    /// covered). If none exists, falls back to the other buffer with
    /// diagnostics whose path shares the most leading components with the
    /// active buffer's path, and picks its most severe diagnostic. Returns
    /// `None` when no candidate is found anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    // Already covered by the previous request; skip it.
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to another buffer in the project that has diagnostics.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // Lower severity values sort first (most severe).
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1468
    /// Requests an edit prediction from the Zeta2 backend.
    ///
    /// On a background thread: builds the prompt (gathering context either
    /// from previously retrieved excerpts in agentic mode, or from the syntax
    /// index), sends it to the LLM endpoint, parses the returned edits
    /// according to the configured prompt format, and resolves them against
    /// the buffer they apply to. The foreground task then wraps everything in
    /// an `EditPredictionResult`.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let file = active_buffer.read(cx).file();
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Pair each retrieved-context buffer with its snapshot, full path, and
        // retrieved anchor ranges. Buffers without a file are dropped.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Stable ordering: sort by path, then by number of ranges.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                // Acquire the syntax index state lock (if an index exists) for
                // the duration of prompt construction.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // included-file entry (creating one if needed), and
                        // move that entry to the end of the list.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line-based excerpts for the
                        // wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                            trigger,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                            trigger,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // When a debug listener is attached, report the request and
                // hand it a channel it can use to observe the response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Escape hatch for local debugging: skip the network request.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                // A response with no text means "no prediction".
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((Some((request_id, None)), usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path mentioned in the model output to one of the
                // included buffers.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output into concrete edits, per prompt format.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        // These formats signal "no edits" with a sentinel diff header.
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        Some((
                            inputs,
                            edited_buffer,
                            edited_buffer_snapshot.clone(),
                            edits,
                            received_response_at,
                        )),
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, prediction)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // An id without edits means the model predicted nothing.
                let Some((
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = prediction
                else {
                    return Ok(Some(EditPredictionResult {
                        id,
                        prediction: Err(EditPredictionRejectReason::Empty),
                    }));
                };

                // TODO telemetry: duration, etc
                Ok(Some(
                    EditPredictionResult::new(
                        id,
                        &edited_buffer,
                        &edited_buffer_snapshot,
                        edits.into(),
                        buffer_snapshotted_at,
                        received_response_at,
                        inputs,
                        cx,
                    )
                    .await,
                ))
            }
        })
    }
1887
    /// Sends an OpenAI-style request to the edit prediction endpoint.
    ///
    /// The endpoint defaults to the zed.dev LLM URL but can be overridden via
    /// `PREDICT_EDITS_URL`. With the `eval-support` feature, responses are
    /// cached keyed by a hash of the URL plus the serialized request, so
    /// repeated evals avoid re-hitting the network.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Return a cached response when one exists for this exact request;
        // otherwise remember the key so the fresh response can be stored below.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1945
    /// Unpacks an API response pair, recording usage on success and surfacing
    /// an "update required" notification when the server rejected this build.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                // Propagate updated usage stats to the user store, if the
                // server reported any.
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                // A ZedUpdateRequiredError means the server's minimum version
                // exceeds this build: flag it and show a notification with an
                // update link. The error is still propagated to the caller.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1989
    /// Sends an authenticated POST to the LLM service and deserializes the
    /// JSON response.
    ///
    /// Retries exactly once with a freshly refreshed token when the server
    /// reports an expired LLM token. Fails with `ZedUpdateRequiredError` when
    /// the server's minimum-version header says this build is too old.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server can demand a minimum app version; surface a typed
            // error so the UI can prompt the user to update.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh and loop around exactly once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2053
    /// Gap since the last refresh-triggering edit after which the next edit
    /// counts as coming out of an idle period (refreshing context immediately).
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied after an edit before refreshing context mid-session.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2056
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context refresh only applies to the agentic retrieval mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // "Idle" = no refresh-triggering edit within the idle duration.
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Assigning here replaces (and thereby drops) any previous debounce
        // task, so only the latest edit's timer is live.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming out of idle: refresh immediately.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-session: wait for a pause in editing first.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold onto the refresh task so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2104
2105 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2106 // and avoid spawning more than one concurrent task.
2107 pub fn refresh_context(
2108 &mut self,
2109 project: Entity<Project>,
2110 buffer: Entity<language::Buffer>,
2111 cursor_position: language::Anchor,
2112 cx: &mut Context<Self>,
2113 ) -> Task<Result<()>> {
2114 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2115 return Task::ready(anyhow::Ok(()));
2116 };
2117
2118 let ContextMode::Agentic(options) = &self.options().context else {
2119 return Task::ready(anyhow::Ok(()));
2120 };
2121
2122 let snapshot = buffer.read(cx).snapshot();
2123 let cursor_point = cursor_position.to_point(&snapshot);
2124 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
2125 cursor_point,
2126 &snapshot,
2127 &options.excerpt,
2128 None,
2129 ) else {
2130 return Task::ready(Ok(()));
2131 };
2132
2133 let app_version = AppVersion::global(cx);
2134 let client = self.client.clone();
2135 let llm_token = self.llm_token.clone();
2136 let debug_tx = self.debug_tx.clone();
2137 let current_file_path: Arc<Path> = snapshot
2138 .file()
2139 .map(|f| f.full_path(cx).into())
2140 .unwrap_or_else(|| Path::new("untitled").into());
2141
2142 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
2143 predict_edits_v3::PlanContextRetrievalRequest {
2144 excerpt: cursor_excerpt.text(&snapshot).body,
2145 excerpt_path: current_file_path,
2146 excerpt_line_range: cursor_excerpt.line_range,
2147 cursor_file_max_row: Line(snapshot.max_point().row),
2148 events: zeta_project.events(cx),
2149 },
2150 ) {
2151 Ok(prompt) => prompt,
2152 Err(err) => {
2153 return Task::ready(Err(err));
2154 }
2155 };
2156
2157 if let Some(debug_tx) = &debug_tx {
2158 debug_tx
2159 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
2160 ZetaContextRetrievalStartedDebugInfo {
2161 project: project.clone(),
2162 timestamp: Instant::now(),
2163 search_prompt: prompt.clone(),
2164 },
2165 ))
2166 .ok();
2167 }
2168
2169 pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
2170 let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
2171 language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
2172 );
2173
2174 let description = schema
2175 .get("description")
2176 .and_then(|description| description.as_str())
2177 .unwrap()
2178 .to_string();
2179
2180 (schema.into(), description)
2181 });
2182
2183 let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
2184
2185 let request = open_ai::Request {
2186 model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
2187 messages: vec![open_ai::RequestMessage::User {
2188 content: open_ai::MessageContent::Plain(prompt),
2189 }],
2190 stream: false,
2191 max_completion_tokens: None,
2192 stop: Default::default(),
2193 temperature: 0.7,
2194 tool_choice: None,
2195 parallel_tool_calls: None,
2196 tools: vec![open_ai::ToolDefinition::Function {
2197 function: FunctionDefinition {
2198 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
2199 description: Some(tool_description),
2200 parameters: Some(tool_schema),
2201 },
2202 }],
2203 prompt_cache_key: None,
2204 reasoning_effort: None,
2205 };
2206
2207 #[cfg(feature = "eval-support")]
2208 let eval_cache = self.eval_cache.clone();
2209
2210 cx.spawn(async move |this, cx| {
2211 log::trace!("Sending search planning request");
2212 let response = Self::send_raw_llm_request(
2213 request,
2214 client,
2215 llm_token,
2216 app_version,
2217 #[cfg(feature = "eval-support")]
2218 eval_cache.clone(),
2219 #[cfg(feature = "eval-support")]
2220 EvalCacheEntryKind::Context,
2221 )
2222 .await;
2223 let mut response = Self::handle_api_response(&this, response, cx)?;
2224 log::trace!("Got search planning response");
2225
2226 let choice = response
2227 .choices
2228 .pop()
2229 .context("No choices in retrieval response")?;
2230 let open_ai::RequestMessage::Assistant {
2231 content: _,
2232 tool_calls,
2233 } = choice.message
2234 else {
2235 anyhow::bail!("Retrieval response didn't include an assistant message");
2236 };
2237
2238 let mut queries: Vec<SearchToolQuery> = Vec::new();
2239 for tool_call in tool_calls {
2240 let open_ai::ToolCallContent::Function { function } = tool_call.content;
2241 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
2242 log::warn!(
2243 "Context retrieval response tried to call an unknown tool: {}",
2244 function.name
2245 );
2246
2247 continue;
2248 }
2249
2250 let input: SearchToolInput = serde_json::from_str(&function.arguments)
2251 .with_context(|| format!("invalid search json {}", &function.arguments))?;
2252 queries.extend(input.queries);
2253 }
2254
2255 if let Some(debug_tx) = &debug_tx {
2256 debug_tx
2257 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
2258 ZetaSearchQueryDebugInfo {
2259 project: project.clone(),
2260 timestamp: Instant::now(),
2261 search_queries: queries.clone(),
2262 },
2263 ))
2264 .ok();
2265 }
2266
2267 log::trace!("Running retrieval search: {queries:#?}");
2268
2269 let related_excerpts_result = retrieval_search::run_retrieval_searches(
2270 queries,
2271 project.clone(),
2272 #[cfg(feature = "eval-support")]
2273 eval_cache,
2274 cx,
2275 )
2276 .await;
2277
2278 log::trace!("Search queries executed");
2279
2280 if let Some(debug_tx) = &debug_tx {
2281 debug_tx
2282 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
2283 ZetaContextRetrievalDebugInfo {
2284 project: project.clone(),
2285 timestamp: Instant::now(),
2286 },
2287 ))
2288 .ok();
2289 }
2290
2291 this.update(cx, |this, _cx| {
2292 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
2293 return Ok(());
2294 };
2295 zeta_project.refresh_context_task.take();
2296 if let Some(debug_tx) = &this.debug_tx {
2297 debug_tx
2298 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
2299 ZetaContextRetrievalDebugInfo {
2300 project,
2301 timestamp: Instant::now(),
2302 },
2303 ))
2304 .ok();
2305 }
2306 match related_excerpts_result {
2307 Ok(excerpts) => {
2308 zeta_project.context = Some(excerpts);
2309 Ok(())
2310 }
2311 Err(error) => Err(error),
2312 }
2313 })?
2314 })
2315 }
2316
2317 pub fn set_context(
2318 &mut self,
2319 project: Entity<Project>,
2320 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2321 ) {
2322 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2323 zeta_project.context = Some(context);
2324 }
2325 }
2326
2327 fn gather_nearby_diagnostics(
2328 cursor_offset: usize,
2329 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2330 snapshot: &BufferSnapshot,
2331 max_diagnostics_bytes: usize,
2332 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2333 // TODO: Could make this more efficient
2334 let mut diagnostic_groups = Vec::new();
2335 for (language_server_id, diagnostics) in diagnostic_sets {
2336 let mut groups = Vec::new();
2337 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2338 diagnostic_groups.extend(
2339 groups
2340 .into_iter()
2341 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2342 );
2343 }
2344
2345 // sort by proximity to cursor
2346 diagnostic_groups.sort_by_key(|group| {
2347 let range = &group.entries[group.primary_ix].range;
2348 if range.start >= cursor_offset {
2349 range.start - cursor_offset
2350 } else if cursor_offset >= range.end {
2351 cursor_offset - range.end
2352 } else {
2353 (cursor_offset - range.start).min(range.end - cursor_offset)
2354 }
2355 });
2356
2357 let mut results = Vec::new();
2358 let mut diagnostic_groups_truncated = false;
2359 let mut diagnostics_byte_count = 0;
2360 for group in diagnostic_groups {
2361 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2362 diagnostics_byte_count += raw_value.get().len();
2363 if diagnostics_byte_count > max_diagnostics_bytes {
2364 diagnostic_groups_truncated = true;
2365 break;
2366 }
2367 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2368 }
2369
2370 (results, diagnostic_groups_truncated)
2371 }
2372
2373 // TODO: Dedupe with similar code in request_prediction?
2374 pub fn cloud_request_for_zeta_cli(
2375 &mut self,
2376 project: &Entity<Project>,
2377 buffer: &Entity<Buffer>,
2378 position: language::Anchor,
2379 cx: &mut Context<Self>,
2380 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
2381 let project_state = self.projects.get(&project.entity_id());
2382
2383 let index_state = project_state.and_then(|state| {
2384 state
2385 .syntax_index
2386 .as_ref()
2387 .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
2388 });
2389 let options = self.options.clone();
2390 let snapshot = buffer.read(cx).snapshot();
2391 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
2392 return Task::ready(Err(anyhow!("No file path for excerpt")));
2393 };
2394 let worktree_snapshots = project
2395 .read(cx)
2396 .worktrees(cx)
2397 .map(|worktree| worktree.read(cx).snapshot())
2398 .collect::<Vec<_>>();
2399
2400 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
2401 let mut path = f.worktree.read(cx).absolutize(&f.path);
2402 if path.pop() { Some(path) } else { None }
2403 });
2404
2405 cx.background_spawn(async move {
2406 let index_state = if let Some(index_state) = index_state {
2407 Some(index_state.lock_owned().await)
2408 } else {
2409 None
2410 };
2411
2412 let cursor_point = position.to_point(&snapshot);
2413
2414 let debug_info = true;
2415 EditPredictionContext::gather_context(
2416 cursor_point,
2417 &snapshot,
2418 parent_abs_path.as_deref(),
2419 match &options.context {
2420 ContextMode::Agentic(_) => {
2421 // TODO
2422 panic!("Llm mode not supported in zeta cli yet");
2423 }
2424 ContextMode::Syntax(edit_prediction_context_options) => {
2425 edit_prediction_context_options
2426 }
2427 },
2428 index_state.as_deref(),
2429 )
2430 .context("Failed to select excerpt")
2431 .map(|context| {
2432 make_syntax_context_cloud_request(
2433 excerpt_path.into(),
2434 context,
2435 // TODO pass everything
2436 Vec::new(),
2437 false,
2438 Vec::new(),
2439 false,
2440 None,
2441 debug_info,
2442 &worktree_snapshots,
2443 index_state.as_deref(),
2444 Some(options.max_prompt_bytes),
2445 options.prompt_format,
2446 PredictEditsRequestTrigger::Other,
2447 )
2448 })
2449 })
2450 }
2451
2452 pub fn wait_for_initial_indexing(
2453 &mut self,
2454 project: &Entity<Project>,
2455 cx: &mut Context<Self>,
2456 ) -> Task<Result<()>> {
2457 let zeta_project = self.get_or_init_zeta_project(project, cx);
2458 if let Some(syntax_index) = &zeta_project.syntax_index {
2459 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2460 } else {
2461 Task::ready(Ok(()))
2462 }
2463 }
2464
2465 fn is_file_open_source(
2466 &self,
2467 project: &Entity<Project>,
2468 file: &Arc<dyn File>,
2469 cx: &App,
2470 ) -> bool {
2471 if !file.is_local() || file.is_private() {
2472 return false;
2473 }
2474 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2475 return false;
2476 };
2477 zeta_project
2478 .license_detection_watchers
2479 .get(&file.worktree_id(cx))
2480 .as_ref()
2481 .is_some_and(|watcher| watcher.is_project_open_source())
2482 }
2483
2484 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2485 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2486 }
2487
2488 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2489 if !self.data_collection_choice.is_enabled() {
2490 return false;
2491 }
2492 events.iter().all(|event| {
2493 matches!(
2494 event.as_ref(),
2495 Event::BufferChange {
2496 in_open_source_repo: true,
2497 ..
2498 }
2499 )
2500 })
2501 }
2502
2503 fn load_data_collection_choice() -> DataCollectionChoice {
2504 let choice = KEY_VALUE_STORE
2505 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2506 .log_err()
2507 .flatten();
2508
2509 match choice.as_deref() {
2510 Some("true") => DataCollectionChoice::Enabled,
2511 Some("false") => DataCollectionChoice::Disabled,
2512 Some(_) => {
2513 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2514 DataCollectionChoice::NotAnswered
2515 }
2516 None => DataCollectionChoice::NotAnswered,
2517 }
2518 }
2519
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2523
    /// Number of predictions that have been shown to the user.
    // NOTE(review): the name says "completions" while the field says
    // "predictions" — consider renaming for consistency with callers.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2527
    /// Whether the user has already submitted a rating for this prediction.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2531
    /// Records the user's rating for `prediction`, reports it to telemetry,
    /// and flushes the telemetry queue immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away so the rating isn't lost if the app quits soon after.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2550}
2551
2552pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2553 let choice = res.choices.pop()?;
2554 let output_text = match choice.message {
2555 open_ai::RequestMessage::Assistant {
2556 content: Some(open_ai::MessageContent::Plain(content)),
2557 ..
2558 } => content,
2559 open_ai::RequestMessage::Assistant {
2560 content: Some(open_ai::MessageContent::Multipart(mut content)),
2561 ..
2562 } => {
2563 if content.is_empty() {
2564 log::error!("No output from Baseten completion response");
2565 return None;
2566 }
2567
2568 match content.remove(0) {
2569 open_ai::MessagePart::Text { text } => text,
2570 open_ai::MessagePart::Image { .. } => {
2571 log::error!("Expected text, got an image");
2572 return None;
2573 }
2574 }
2575 }
2576 _ => {
2577 log::error!("Invalid response message: {:?}", choice.message);
2578 return None;
2579 }
2580 };
2581 Some(output_text)
2582}
2583
/// Error raised when the prediction backend requires a newer Zed build.
// NOTE(review): presumably populated from the
// `MINIMUM_REQUIRED_VERSION_HEADER_NAME` response header — confirm at the
// construction site.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2591
2592fn make_syntax_context_cloud_request(
2593 excerpt_path: Arc<Path>,
2594 context: EditPredictionContext,
2595 events: Vec<Arc<predict_edits_v3::Event>>,
2596 can_collect_data: bool,
2597 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2598 diagnostic_groups_truncated: bool,
2599 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2600 debug_info: bool,
2601 worktrees: &Vec<worktree::Snapshot>,
2602 index_state: Option<&SyntaxIndexState>,
2603 prompt_max_bytes: Option<usize>,
2604 prompt_format: PromptFormat,
2605 trigger: PredictEditsRequestTrigger,
2606) -> predict_edits_v3::PredictEditsRequest {
2607 let mut signatures = Vec::new();
2608 let mut declaration_to_signature_index = HashMap::default();
2609 let mut referenced_declarations = Vec::new();
2610
2611 for snippet in context.declarations {
2612 let project_entry_id = snippet.declaration.project_entry_id();
2613 let Some(path) = worktrees.iter().find_map(|worktree| {
2614 worktree.entry_for_id(project_entry_id).map(|entry| {
2615 let mut full_path = RelPathBuf::new();
2616 full_path.push(worktree.root_name());
2617 full_path.push(&entry.path);
2618 full_path
2619 })
2620 }) else {
2621 continue;
2622 };
2623
2624 let parent_index = index_state.and_then(|index_state| {
2625 snippet.declaration.parent().and_then(|parent| {
2626 add_signature(
2627 parent,
2628 &mut declaration_to_signature_index,
2629 &mut signatures,
2630 index_state,
2631 )
2632 })
2633 });
2634
2635 let (text, text_is_truncated) = snippet.declaration.item_text();
2636 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2637 path: path.as_std_path().into(),
2638 text: text.into(),
2639 range: snippet.declaration.item_line_range(),
2640 text_is_truncated,
2641 signature_range: snippet.declaration.signature_range_in_item_text(),
2642 parent_index,
2643 signature_score: snippet.score(DeclarationStyle::Signature),
2644 declaration_score: snippet.score(DeclarationStyle::Declaration),
2645 score_components: snippet.components,
2646 });
2647 }
2648
2649 let excerpt_parent = index_state.and_then(|index_state| {
2650 context
2651 .excerpt
2652 .parent_declarations
2653 .last()
2654 .and_then(|(parent, _)| {
2655 add_signature(
2656 *parent,
2657 &mut declaration_to_signature_index,
2658 &mut signatures,
2659 index_state,
2660 )
2661 })
2662 });
2663
2664 predict_edits_v3::PredictEditsRequest {
2665 excerpt_path,
2666 excerpt: context.excerpt_text.body,
2667 excerpt_line_range: context.excerpt.line_range,
2668 excerpt_range: context.excerpt.range,
2669 cursor_point: predict_edits_v3::Point {
2670 line: predict_edits_v3::Line(context.cursor_point.row),
2671 column: context.cursor_point.column,
2672 },
2673 referenced_declarations,
2674 included_files: vec![],
2675 signatures,
2676 excerpt_parent,
2677 events,
2678 can_collect_data,
2679 diagnostic_groups,
2680 diagnostic_groups_truncated,
2681 git_info,
2682 debug_info,
2683 prompt_max_bytes,
2684 prompt_format,
2685 trigger,
2686 }
2687}
2688
2689fn add_signature(
2690 declaration_id: DeclarationId,
2691 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2692 signatures: &mut Vec<Signature>,
2693 index: &SyntaxIndexState,
2694) -> Option<usize> {
2695 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2696 return Some(*signature_index);
2697 }
2698 let Some(parent_declaration) = index.declaration(declaration_id) else {
2699 log::error!("bug: missing parent declaration");
2700 return None;
2701 };
2702 let parent_index = parent_declaration.parent().and_then(|parent| {
2703 add_signature(parent, declaration_to_signature_index, signatures, index)
2704 });
2705 let (text, text_is_truncated) = parent_declaration.signature_text();
2706 let signature_index = signatures.len();
2707 signatures.push(Signature {
2708 text: text.into(),
2709 text_is_truncated,
2710 parent_index,
2711 range: parent_declaration.signature_line_range(),
2712 });
2713 declaration_to_signature_index.insert(declaration_id, signature_index);
2714 Some(signature_index)
2715}
2716
/// Cache key for recorded LLM interactions during evals: the kind of
/// interaction plus a hash of the request.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// The kind of LLM interaction being cached during evals.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    // Renders the lowercase kind name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

/// Read-through cache of LLM responses, used to make evals reproducible.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`; `input` is presumably the request payload,
    /// kept for inspection — confirm with implementors.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2744
/// The user's choice about sharing edit-prediction data, persisted in the
/// key-value store under `ZED_PREDICT_DATA_COLLECTION_CHOICE`.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not responded to the data-collection prompt yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}
2751
2752impl DataCollectionChoice {
2753 pub fn is_enabled(self) -> bool {
2754 match self {
2755 Self::Enabled => true,
2756 Self::NotAnswered | Self::Disabled => false,
2757 }
2758 }
2759
2760 pub fn is_answered(self) -> bool {
2761 match self {
2762 Self::Enabled | Self::Disabled => true,
2763 Self::NotAnswered => false,
2764 }
2765 }
2766
2767 #[must_use]
2768 pub fn toggle(&self) -> DataCollectionChoice {
2769 match self {
2770 Self::Enabled => Self::Disabled,
2771 Self::Disabled => Self::Enabled,
2772 Self::NotAnswered => Self::Enabled,
2773 }
2774 }
2775}
2776
2777impl From<bool> for DataCollectionChoice {
2778 fn from(value: bool) -> Self {
2779 match value {
2780 true => DataCollectionChoice::Enabled,
2781 false => DataCollectionChoice::Disabled,
2782 }
2783 }
2784}
2785
2786struct ZedPredictUpsell;
2787
2788impl Dismissable for ZedPredictUpsell {
2789 const KEY: &'static str = "dismissed-edit-predict-upsell";
2790
2791 fn dismissed() -> bool {
2792 // To make this backwards compatible with older versions of Zed, we
2793 // check if the user has seen the previous Edit Prediction Onboarding
2794 // before, by checking the data collection choice which was written to
2795 // the database once the user clicked on "Accept and Enable"
2796 if KEY_VALUE_STORE
2797 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2798 .log_err()
2799 .is_some_and(|s| s.is_some())
2800 {
2801 return true;
2802 }
2803
2804 KEY_VALUE_STORE
2805 .read_kvp(Self::KEY)
2806 .log_err()
2807 .is_some_and(|s| s.is_some())
2808 }
2809}
2810
2811pub fn should_show_upsell_modal() -> bool {
2812 !ZedPredictUpsell::dismissed()
2813}
2814
/// Registers edit-prediction workspace actions and installs command-palette
/// feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // The rating UI is only available behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            // NOTE(review): "resetting" sets the provider to `None`, which
            // presumably re-triggers onboarding by disabling the provider —
            // confirm this is the intended reset mechanism.
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2850
2851fn feature_gate_predict_edits_actions(cx: &mut App) {
2852 let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
2853 let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
2854 let zeta_all_action_types = [
2855 TypeId::of::<RateCompletions>(),
2856 TypeId::of::<ResetOnboarding>(),
2857 zed_actions::OpenZedPredictOnboarding.type_id(),
2858 TypeId::of::<ClearHistory>(),
2859 TypeId::of::<ThumbsUpActivePrediction>(),
2860 TypeId::of::<ThumbsDownActivePrediction>(),
2861 TypeId::of::<NextEdit>(),
2862 TypeId::of::<PreviousEdit>(),
2863 ];
2864
2865 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2866 filter.hide_action_types(&rate_completion_action_types);
2867 filter.hide_action_types(&reset_onboarding_action_types);
2868 filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
2869 });
2870
2871 cx.observe_global::<SettingsStore>(move |cx| {
2872 let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
2873 let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
2874
2875 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2876 if is_ai_disabled {
2877 filter.hide_action_types(&zeta_all_action_types);
2878 } else if has_feature_flag {
2879 filter.show_action_types(&rate_completion_action_types);
2880 } else {
2881 filter.hide_action_types(&rate_completion_action_types);
2882 }
2883 });
2884 })
2885 .detach();
2886
2887 cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
2888 if !DisableAiSettings::get_global(cx).disable_ai {
2889 if is_enabled {
2890 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2891 filter.show_action_types(&rate_completion_action_types);
2892 });
2893 } else {
2894 CommandPaletteFilter::update_global(cx, |filter, _cx| {
2895 filter.hide_action_types(&rate_completion_action_types);
2896 });
2897 }
2898 }
2899 })
2900 .detach();
2901}
2902
2903#[cfg(test)]
2904mod tests {
2905 use std::{path::Path, sync::Arc};
2906
2907 use client::UserStore;
2908 use clock::FakeSystemClock;
2909 use cloud_llm_client::{
2910 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2911 };
2912 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2913 use futures::{
2914 AsyncReadExt, StreamExt,
2915 channel::{mpsc, oneshot},
2916 };
2917 use gpui::{
2918 Entity, TestAppContext,
2919 http_client::{FakeHttpClient, Response},
2920 prelude::*,
2921 };
2922 use indoc::indoc;
2923 use language::OffsetRangeExt as _;
2924 use open_ai::Usage;
2925 use pretty_assertions::{assert_eq, assert_matches};
2926 use project::{FakeFs, Project};
2927 use serde_json::json;
2928 use settings::SettingsStore;
2929 use util::path;
2930 use uuid::Uuid;
2931
2932 use crate::{BufferEditPrediction, Zeta};
2933
    // End-to-end check of the prediction/context/rejection state machine:
    // a local prediction, a context refresh via the search tool, and a
    // prediction that jumps to a different file.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        // An edit in the same file should surface as a Local prediction.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        // Respond with a tool call that searches the other file.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction targets another file.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // From the target buffer's perspective the same prediction is Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3080
    // Basic round trip: request a prediction, respond with a unified diff,
    // and check the resulting edit is anchored at the expected position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // The diff rewrites line 1; the edit should begin at the insertion
        // point within that line.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3144
    // Checks that buffer edits made before a prediction request are included
    // in the prompt as a unified diff of the change history.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so its edits are tracked as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above should appear in the prompt as a diff hunk.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3218
    // A model response whose diff is a no-op should produce no current
    // prediction and be auto-reported as rejected with reason `Empty`.
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // The hunk replaces "How" with the identical line — net zero change.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
             Hello!
            -How
            +How
             Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3282
    // If the user edits the buffer while a request is in flight such that the
    // predicted edit is already applied, the prediction interpolates to
    // nothing and is auto-rejected with reason `InterpolatedEmpty`.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Apply the predicted content manually before the response arrives.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3341
    /// Unified diff fixture shared by several tests below: rewrites the "How"
    /// line of `root/foo.md` to "How are you?".
    const SIMPLE_DIFF: &str = indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are you?
        Bye
    "};
3351
    #[gpui::test]
    async fn test_replace_current(cx: &mut TestAppContext) {
        // When a newer prediction arrives that supersedes the current one, the
        // new prediction takes over and the old one is reported with reason
        // `Replaced`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of the "How" line.
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // second replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as replaced
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }]
        );
    }
3431
    #[gpui::test]
    async fn test_current_preferred(cx: &mut TestAppContext) {
        // When a newer prediction is judged worse than the current one, the
        // current prediction is kept and the new one is rejected with reason
        // `CurrentPreferred`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of the "How" line.
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_tx.send(first_response).unwrap();

        cx.run_until_parked();

        // The first prediction becomes current.
        zeta.read_with(cx, |zeta, cx| {
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // a second request is triggered
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        // worse than current prediction
        let second_response = model_response(indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are
            Bye
        "});
        let second_id = second_response.id.clone();
        respond_tx.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // first is preferred over second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: second_id,
                reason: EditPredictionRejectReason::CurrentPreferred,
                was_shown: false
            }]
        );
    }
3520
    #[gpui::test]
    async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
        // If a later request completes before an earlier in-flight one, the
        // earlier request is cancelled: its late response is ignored and it is
        // reported with reason `Canceled`.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of the "How" line.
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        // wait for throttle
        cx.run_until_parked();

        // second responds first
        let second_response = model_response(SIMPLE_DIFF);
        let second_id = second_response.id.clone();
        respond_second.send(second_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is second
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // Now let the first (stale) request respond; its result must be
        // discarded rather than replacing the newer prediction.
        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still second, since first was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                second_id
            );
        });

        // first is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            }]
        );
    }
3606
    #[gpui::test]
    async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
        // With two requests already pending, starting a third cancels the
        // second (the most recent pending one is kept plus the new one). The
        // cancelled request's response is ignored, and the final reject report
        // contains both the cancellation and the later replacement.
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of the "How" line.
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            // start two refresh tasks
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // wait for throttle, so requests are sent
        cx.run_until_parked();

        let (_, respond_first) = requests.predict.next().await.unwrap();
        let (_, respond_second) = requests.predict.next().await.unwrap();

        zeta.update(cx, |zeta, cx| {
            // start a third request
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

            // 2 are pending, so 2nd is cancelled
            assert_eq!(
                zeta.get_or_init_zeta_project(&project, cx)
                    .cancelled_predictions
                    .iter()
                    .copied()
                    .collect::<Vec<_>>(),
                [1]
            );
        });

        // wait for throttle
        cx.run_until_parked();

        let (_, respond_third) = requests.predict.next().await.unwrap();

        let first_response = model_response(SIMPLE_DIFF);
        let first_id = first_response.id.clone();
        respond_first.send(first_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        // Respond on the cancelled (second) request; its prediction must be
        // discarded.
        let cancelled_response = model_response(SIMPLE_DIFF);
        let cancelled_id = cancelled_response.id.clone();
        respond_second.send(cancelled_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // current prediction is still first, since second was cancelled
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                first_id
            );
        });

        let third_response = model_response(SIMPLE_DIFF);
        let third_response_id = third_response.id.clone();
        respond_third.send(third_response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            // third completes and replaces first
            assert_eq!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .unwrap()
                    .id
                    .0,
                third_response_id
            );
        });

        // second is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        cx.run_until_parked();

        assert_eq!(
            &reject_request.rejections,
            &[
                EditPredictionRejection {
                    request_id: cancelled_id,
                    reason: EditPredictionRejectReason::Canceled,
                    was_shown: false
                },
                EditPredictionRejection {
                    request_id: first_id,
                    reason: EditPredictionRejectReason::Replaced,
                    was_shown: false
                }
            ]
        );
    }
3734
3735 // Skipped until we start including diagnostics in prompt
3736 // #[gpui::test]
3737 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3738 // let (zeta, mut req_rx) = init_test(cx);
3739 // let fs = FakeFs::new(cx.executor());
3740 // fs.insert_tree(
3741 // "/root",
3742 // json!({
3743 // "foo.md": "Hello!\nBye"
3744 // }),
3745 // )
3746 // .await;
3747 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3748
3749 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3750 // let diagnostic = lsp::Diagnostic {
3751 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3752 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3753 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3754 // ..Default::default()
3755 // };
3756
3757 // project.update(cx, |project, cx| {
3758 // project.lsp_store().update(cx, |lsp_store, cx| {
3759 // // Create some diagnostics
3760 // lsp_store
3761 // .update_diagnostics(
3762 // LanguageServerId(0),
3763 // lsp::PublishDiagnosticsParams {
3764 // uri: path_to_buffer_uri.clone(),
3765 // diagnostics: vec![diagnostic],
3766 // version: None,
3767 // },
3768 // None,
3769 // language::DiagnosticSourceKind::Pushed,
3770 // &[],
3771 // cx,
3772 // )
3773 // .unwrap();
3774 // });
3775 // });
3776
3777 // let buffer = project
3778 // .update(cx, |project, cx| {
3779 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3780 // project.open_buffer(path, cx)
3781 // })
3782 // .await
3783 // .unwrap();
3784
3785 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3786 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3787
3788 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3789 // zeta.request_prediction(&project, &buffer, position, cx)
3790 // });
3791
3792 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3793
3794 // assert_eq!(request.diagnostic_groups.len(), 1);
3795 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3796 // .unwrap();
3797 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3798 // assert_eq!(
3799 // value,
3800 // json!({
3801 // "entries": [{
3802 // "range": {
3803 // "start": 8,
3804 // "end": 10
3805 // },
3806 // "diagnostic": {
3807 // "source": null,
3808 // "code": null,
3809 // "code_description": null,
3810 // "severity": 1,
3811 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3812 // "markdown": null,
3813 // "group_id": 0,
3814 // "is_primary": true,
3815 // "is_disk_based": false,
3816 // "is_unnecessary": false,
3817 // "source_kind": "Pushed",
3818 // "data": null,
3819 // "underline": true
3820 // }
3821 // }],
3822 // "primary_ix": 0
3823 // })
3824 // );
3825 // }
3826
3827 fn model_response(text: &str) -> open_ai::Response {
3828 open_ai::Response {
3829 id: Uuid::new_v4().to_string(),
3830 object: "response".into(),
3831 created: 0,
3832 model: "model".into(),
3833 choices: vec![open_ai::Choice {
3834 index: 0,
3835 message: open_ai::RequestMessage::Assistant {
3836 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3837 tool_calls: vec![],
3838 },
3839 finish_reason: None,
3840 }],
3841 usage: Usage {
3842 prompt_tokens: 0,
3843 completion_tokens: 0,
3844 total_tokens: 0,
3845 },
3846 }
3847 }
3848
3849 fn prompt_from_request(request: &open_ai::Request) -> &str {
3850 assert_eq!(request.messages.len(), 1);
3851 let open_ai::RequestMessage::User {
3852 content: open_ai::MessageContent::Plain(content),
3853 ..
3854 } = &request.messages[0]
3855 else {
3856 panic!(
3857 "Request does not have single user message of type Plain. {:#?}",
3858 request
3859 );
3860 };
3861 content
3862 }
3863
    /// Receiving ends of the channels the fake HTTP client (see `init_test`)
    /// forwards intercepted requests on. Each item pairs the parsed request
    /// body with a oneshot sender the test uses to supply the response.
    struct RequestChannels {
        // Requests to `/predict_edits/raw`.
        predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
        // Requests to `/predict_edits/reject`.
        reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
    }
3868
    /// Sets up a test app with a `Zeta` entity backed by a fake HTTP client.
    ///
    /// The fake client answers `/client/llm_tokens` directly with a dummy
    /// token, and forwards `/predict_edits/raw` and `/predict_edits/reject`
    /// request bodies over the returned [`RequestChannels`], so tests can
    /// inspect each request and decide when and how the "server" responds.
    fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            // Channels used to hand intercepted requests back to the test.
            let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
            let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    // Clone the senders into the per-request future.
                    let predict_req_tx = predict_req_tx.clone();
                    let reject_req_tx = reject_req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: reply immediately with a dummy token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction request: parse the body, hand it to the
                            // test, and block until the test sends a response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Rejection report: same forwarding scheme as above.
                            "/predict_edits/reject" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Any other endpoint indicates a test bug.
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (
                zeta,
                RequestChannels {
                    predict: predict_req_rx,
                    reject: reject_req_rx,
                },
            )
        })
    }
3935}