1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
7 EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
8 MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBody,
9 ZED_VERSION_HEADER_NAME,
10};
11use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
12use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
13use collections::{HashMap, HashSet};
14use command_palette_hooks::CommandPaletteFilter;
15use db::kvp::{Dismissable, KEY_VALUE_STORE};
16use edit_prediction_context::{
17 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
18 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
19 SyntaxIndex, SyntaxIndexState,
20};
21use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
22use futures::channel::{mpsc, oneshot};
23use futures::{AsyncReadExt as _, StreamExt as _};
24use gpui::{
25 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
26 http_client::{self, AsyncBody, Method},
27 prelude::*,
28};
29use language::{
30 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
31};
32use language::{BufferSnapshot, OffsetRangeExt};
33use language_model::{LlmApiToken, RefreshLlmTokenListener};
34use open_ai::FunctionDefinition;
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
40use std::any::{Any as _, TypeId};
41use std::collections::{VecDeque, hash_map};
42use telemetry_events::EditPredictionRating;
43use workspace::Workspace;
44
45use std::ops::Range;
46use std::path::Path;
47use std::rc::Rc;
48use std::str::FromStr as _;
49use std::sync::{Arc, LazyLock};
50use std::time::{Duration, Instant};
51use std::{env, mem};
52use thiserror::Error;
53use util::rel_path::RelPathBuf;
54use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
55use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
56
57pub mod assemble_excerpts;
58mod license_detection;
59mod onboarding_modal;
60mod prediction;
61mod provider;
62mod rate_prediction_modal;
63pub mod retrieval_search;
64mod sweep_ai;
65pub mod udiff;
66mod xml_edits;
67pub mod zeta1;
68
69#[cfg(test)]
70mod zeta_tests;
71
72use crate::assemble_excerpts::assemble_excerpts;
73use crate::license_detection::LicenseDetectionWatcher;
74use crate::onboarding_modal::ZedPredictModal;
75pub use crate::prediction::EditPrediction;
76pub use crate::prediction::EditPredictionId;
77pub use crate::prediction::EditPredictionInputs;
78use crate::prediction::EditPredictionResult;
79use crate::rate_prediction_modal::{
80 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
81 ThumbsUpActivePrediction,
82};
83use crate::sweep_ai::SweepAi;
84use crate::zeta1::request_prediction_with_zeta1;
85pub use provider::ZetaEditPredictionProvider;
86
// Action definitions in the `edit_prediction` namespace. The doc comment on
// each action is surfaced as its description in the command palette.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
98
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits within this many rows of each other are coalesced into a
/// single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value store key under which the user's data collection choice persists.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
103
104pub struct SweepFeatureFlag;
105
106impl FeatureFlag for SweepFeatureFlag {
107 const NAME: &str = "sweep-ai";
108}
/// Default sizing of the excerpt taken around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Context retrieval defaults to the agentic (search-based) mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for the syntax-index context mode. Declaration retrieval is
/// disabled (`max_retrieved_declarations: 0`); only imports and the excerpt
/// around the cursor are used.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used when a `Zeta` is first created.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
140
// Setting ZED_ZETA2_OLLAMA to any non-empty value routes model traffic to a
// local Ollama server instead of the hosted endpoints.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
// Model id used for the context-retrieval (search) step; overridable via
// ZED_ZETA2_CONTEXT_MODEL.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
// Model id used for the edit-prediction step; overridable via ZED_ZETA2_MODEL
// ("zeta2-exp" is a shorthand for the fine-tuned model).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
// Prediction endpoint override: ZED_PREDICT_EDITS_URL wins, then the local
// Ollama URL when enabled; `None` means use the zed.dev LLM service.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
168
/// Feature flag gating the zeta2 prediction path; always enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

/// Holds the app-wide [`Zeta`] entity (see [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
183
/// Central state for edit predictions, shared across all projects through
/// [`ZetaGlobal`].
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token used to authenticate against the cloud LLM endpoints.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by project entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server demands a newer client (see
    // MINIMUM_REQUIRED_VERSION_HEADER_NAME import) — set outside this chunk.
    update_required: bool,
    // When present, internal debug events are mirrored to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections buffered locally until flushed in a single batched request.
    rejected_predictions: Vec<EditPredictionRejection>,
    // Signals the background task spawned in `new` to flush rejections.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    // Recently shown predictions, most recent first (bounded at 50).
    shown_predictions: VecDeque<EditPrediction>,
    // Ids of shown predictions the user has already rated.
    rated_predictions: HashSet<EditPredictionId>,
}

/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}

/// Tunable options controlling context gathering and prompt assembly.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    // Model-driven search over the project.
    Agentic(AgenticContextOptions),
    // Retrieval backed by a `SyntaxIndex`.
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
233
234impl ContextMode {
235 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
236 match self {
237 ContextMode::Agentic(options) => &options.excerpt,
238 ContextMode::Syntax(options) => &options.excerpt,
239 }
240 }
241}
242
/// Events emitted on the debug channel installed by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // The prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    // Time spent gathering context before the request was issued.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // The locally-assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
283
/// Per-project prediction state.
struct ZetaProject {
    // Present only when `ContextMode::Syntax` was active at creation time.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Finalized buffer-change events, oldest first (bounded by EVENT_COUNT_MAX).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    // The most recent change, kept un-finalized so nearby edits can coalesce.
    last_event: Option<LastEvent>,
    // Most-recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    // In-flight prediction requests; capacity 2 bounds concurrency.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Entity and time of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    // Ids of pending predictions that were cancelled before resolving.
    cancelled_predictions: HashSet<usize>,
    // Retrieved context ranges per buffer, when available.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    // One license watcher per worktree, used to mark events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
302
303impl ZetaProject {
304 pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
305 self.events
306 .iter()
307 .cloned()
308 .chain(
309 self.last_event
310 .as_ref()
311 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
312 )
313 .collect()
314 }
315
316 fn cancel_pending_prediction(
317 &mut self,
318 pending_prediction: PendingPrediction,
319 cx: &mut Context<Zeta>,
320 ) {
321 self.cancelled_predictions.insert(pending_prediction.id);
322
323 cx.spawn(async move |this, cx| {
324 let Some(prediction_id) = pending_prediction.task.await else {
325 return;
326 };
327
328 this.update(cx, |this, cx| {
329 this.reject_prediction(
330 prediction_id,
331 EditPredictionRejectReason::Canceled,
332 false,
333 cx,
334 );
335 })
336 .ok();
337 })
338 .detach()
339 }
340}
341
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    // What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    // Whether the prediction has been displayed to the user yet.
    pub was_shown: bool,
}
348
349impl CurrentEditPrediction {
350 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
351 let Some(new_edits) = self
352 .prediction
353 .interpolate(&self.prediction.buffer.read(cx))
354 else {
355 return false;
356 };
357
358 if self.prediction.buffer != old_prediction.prediction.buffer {
359 return true;
360 }
361
362 let Some(old_edits) = old_prediction
363 .prediction
364 .interpolate(&old_prediction.prediction.buffer.read(cx))
365 else {
366 return true;
367 };
368
369 let requested_by_buffer_id = self.requested_by.buffer_id();
370
371 // This reduces the occurrence of UI thrash from replacing edits
372 //
373 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
374 if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
375 && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
376 && old_edits.len() == 1
377 && new_edits.len() == 1
378 {
379 let (old_range, old_text) = &old_edits[0];
380 let (new_range, new_text) = &new_edits[0];
381 new_range == old_range && new_text.starts_with(old_text.as_ref())
382 } else {
383 true
384 }
385 }
386}
387
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    // The buffer (by entity id) in which the user was editing.
    Buffer(EntityId),
}
393
394impl PredictionRequestedBy {
395 pub fn buffer_id(&self) -> Option<EntityId> {
396 match self {
397 PredictionRequestedBy::DiagnosticsUpdate => None,
398 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
399 }
400 }
401}
402
/// A queued prediction request. `task` resolves to the resulting prediction's
/// id, or `None` when no prediction was produced.
#[derive(Debug)]
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
415
#[cfg(test)]
impl std::ops::Deref for BufferEditPrediction<'_> {
    type Target = EditPrediction;

    /// Both variants wrap the same reference type, so an or-pattern suffices.
    fn deref(&self) -> &Self::Target {
        let (Self::Local { prediction } | Self::Jump { prediction }) = self;
        prediction
    }
}
427
/// Per-buffer bookkeeping: the last snapshot we recorded, plus the edit and
/// release subscriptions that keep it current.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// A buffer change that hasn't been finalized into an event yet, so that
/// consecutive nearby edits can be coalesced into one event.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    // End of the most recent edit, used for proximity-based grouping.
    end_edit_anchor: Option<Anchor>,
}
438
439impl LastEvent {
440 pub fn finalize(
441 &self,
442 license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
443 cx: &App,
444 ) -> Option<Arc<predict_edits_v3::Event>> {
445 let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
446 let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
447
448 let file = self.new_snapshot.file();
449 let old_file = self.old_snapshot.file();
450
451 let in_open_source_repo = [file, old_file].iter().all(|file| {
452 file.is_some_and(|file| {
453 license_detection_watchers
454 .get(&file.worktree_id(cx))
455 .is_some_and(|watcher| watcher.is_project_open_source())
456 })
457 });
458
459 let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
460
461 if path == old_path && diff.is_empty() {
462 None
463 } else {
464 Some(Arc::new(predict_edits_v3::Event::BufferChange {
465 old_path,
466 path,
467 diff,
468 in_open_source_repo,
469 // TODO: Actually detect if this edit was predicted or not
470 predicted: false,
471 }))
472 }
473 }
474}
475
476fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
477 if let Some(file) = snapshot.file() {
478 file.full_path(cx).into()
479 } else {
480 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
481 }
482}
483
484impl Zeta {
485 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
486 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
487 }
488
489 pub fn global(
490 client: &Arc<Client>,
491 user_store: &Entity<UserStore>,
492 cx: &mut App,
493 ) -> Entity<Self> {
494 cx.try_global::<ZetaGlobal>()
495 .map(|global| global.0.clone())
496 .unwrap_or_else(|| {
497 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
498 cx.set_global(ZetaGlobal(zeta.clone()));
499 zeta
500 })
501 }
502
    /// Creates a new `Zeta`. Also spawns the long-lived background loop that
    /// flushes batched prediction rejections whenever the reject channel fires.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Rejections are reported in batches: `reject_prediction` debounces
        // sends on this channel, and this task performs the actual upload.
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.flush_rejected_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
550
    /// Selects which backend model serves predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }

    /// Installs a cache used by evals to avoid repeated model calls.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
563
564 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
565 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
566 self.debug_tx = Some(debug_watch_tx);
567 debug_watch_rx
568 }
569
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    ///
    /// NOTE(review): already-registered projects keep their existing
    /// `syntax_index`; a `ContextMode` change presumably only takes effect for
    /// newly registered projects — confirm against `get_or_init_zeta_project`.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
577
578 pub fn clear_history(&mut self) {
579 for zeta_project in self.projects.values_mut() {
580 zeta_project.events.clear();
581 }
582 }
583
584 pub fn context_for_project(
585 &self,
586 project: &Entity<Project>,
587 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
588 self.projects
589 .get(&project.entity_id())
590 .and_then(|project| {
591 Some(
592 project
593 .context
594 .as_ref()?
595 .iter()
596 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
597 )
598 })
599 .into_iter()
600 .flatten()
601 }
602
603 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
604 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
605 self.user_store.read(cx).edit_prediction_usage()
606 } else {
607 None
608 }
609 }
610
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` (snapshot, edit events, license detection)
    /// within `project`'s state.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
624
625 fn get_or_init_zeta_project(
626 &mut self,
627 project: &Entity<Project>,
628 cx: &mut Context<Self>,
629 ) -> &mut ZetaProject {
630 self.projects
631 .entry(project.entity_id())
632 .or_insert_with(|| ZetaProject {
633 syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
634 Some(cx.new(|cx| {
635 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
636 }))
637 } else {
638 None
639 },
640 events: VecDeque::new(),
641 last_event: None,
642 recent_paths: VecDeque::new(),
643 registered_buffers: HashMap::default(),
644 current_prediction: None,
645 cancelled_predictions: HashSet::default(),
646 pending_predictions: ArrayVec::new(),
647 next_pending_prediction_id: 0,
648 last_prediction_refresh: None,
649 context: None,
650 refresh_context_task: None,
651 refresh_context_debounce_task: None,
652 refresh_context_timestamp: None,
653 license_detection_watchers: HashMap::default(),
654 _subscription: cx.subscribe(&project, Self::handle_project_event),
655 })
656 }
657
    /// Reacts to project events: keeps the MRU list of recently active paths,
    /// and (behind the zeta2 flag) refreshes predictions on new diagnostics.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any earlier entry.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
690
    /// Ensures `buffer` is tracked in `zeta_project`: installs a license
    /// watcher for its worktree and subscribes to buffer edits and release.
    /// Returns the (possibly newly created) `RegisteredBuffer`.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create one license watcher per worktree, removed again when
        // the worktree entity is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit as a change event.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Stop tracking the buffer once it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
750
    /// Records an edit to `buffer`, coalescing it into the previous pending
    /// event when it continues the same buffer's history within
    /// `CHANGE_GROUPING_LINE_SPAN` rows; otherwise finalizes the previous
    /// event and starts a new pending one.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            // No new edits since the last snapshot we recorded.
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End of the last edit in this batch, used for proximity grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this change picks up exactly where the
            // pending event's snapshot left off...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and both edits end within CHANGE_GROUPING_LINE_SPAN rows.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Bound the history: evict the oldest event to make room for the
        // pending event that may be finalized below.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
813
814 fn current_prediction_for_buffer(
815 &self,
816 buffer: &Entity<Buffer>,
817 project: &Entity<Project>,
818 cx: &App,
819 ) -> Option<BufferEditPrediction<'_>> {
820 let project_state = self.projects.get(&project.entity_id())?;
821
822 let CurrentEditPrediction {
823 requested_by,
824 prediction,
825 ..
826 } = project_state.current_prediction.as_ref()?;
827
828 if prediction.targets_buffer(buffer.read(cx)) {
829 Some(BufferEditPrediction::Local { prediction })
830 } else {
831 let show_jump = match requested_by {
832 PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
833 requested_by_buffer_id == &buffer.entity_id()
834 }
835 PredictionRequestedBy::DiagnosticsUpdate => true,
836 };
837
838 if show_jump {
839 Some(BufferEditPrediction::Jump { prediction })
840 } else {
841 None
842 }
843 }
844 }
845
    /// Accepts the project's current prediction: cancels any queued
    /// predictions and reports the acceptance to the prediction service.
    /// No-op for the Sweep backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes anything still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            project_state.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // ZED_ACCEPT_PREDICTION_URL overrides the accept endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
898
    /// Uploads all currently buffered prediction rejections in one request,
    /// then removes only the entries that were part of this batch (new
    /// rejections may arrive while the request is in flight). No-op for the
    /// Sweep backend or when there is nothing to flush.
    fn flush_rejected_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the batch's last entry so only the flushed prefix is
        // dropped on success.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): `.ok()` silently yields `None` on serialization
        // failure, sending a body built from `None` below — presumably
        // serialization of these types cannot fail; confirm.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            this.update(cx, |this, _| {
                // Drop everything up to and including the batch's last entry.
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
946
947 fn reject_current_prediction(
948 &mut self,
949 reason: EditPredictionRejectReason,
950 project: &Entity<Project>,
951 cx: &mut Context<Self>,
952 ) {
953 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
954 project_state.pending_predictions.clear();
955 if let Some(prediction) = project_state.current_prediction.take() {
956 self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown, cx);
957 }
958 };
959 }
960
961 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
962 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
963 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
964 if !current_prediction.was_shown {
965 current_prediction.was_shown = true;
966 self.shown_predictions
967 .push_front(current_prediction.prediction.clone());
968 if self.shown_predictions.len() > 50 {
969 let completion = self.shown_predictions.pop_back().unwrap();
970 self.rated_predictions.remove(&completion.id);
971 }
972 }
973 }
974 }
975 }
976
977 fn reject_prediction(
978 &mut self,
979 prediction_id: EditPredictionId,
980 reason: EditPredictionRejectReason,
981 was_shown: bool,
982 cx: &mut Context<Self>,
983 ) {
984 self.rejected_predictions.push(EditPredictionRejection {
985 request_id: prediction_id.to_string(),
986 reason,
987 was_shown,
988 });
989
990 let reached_request_limit =
991 self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
992 let reject_tx = self.reject_predictions_tx.clone();
993 self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
994 const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
995 if !reached_request_limit {
996 cx.background_executor()
997 .timer(REJECT_REQUEST_DEBOUNCE)
998 .await;
999 }
1000 reject_tx.unbounded_send(()).log_err();
1001 }));
1002 }
1003
1004 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
1005 self.projects
1006 .get(&project.entity_id())
1007 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
1008 }
1009
    /// Queues a (throttled) prediction refresh triggered by activity in
    /// `buffer` at `position`.
    pub fn refresh_prediction_from_buffer(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
            let Some(request_task) = this
                .update(cx, |this, cx| {
                    this.request_prediction(
                        &project,
                        &buffer,
                        position,
                        PredictEditsRequestTrigger::Other,
                        cx,
                    )
                })
                .log_err()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            // Tag the result with the originating buffer so the UI can decide
            // where to surface it.
            cx.spawn(async move |_cx| {
                request_task.await.map(|prediction_result| {
                    prediction_result.map(|prediction_result| {
                        (
                            prediction_result,
                            PredictionRequestedBy::Buffer(buffer.entity_id()),
                        )
                    })
                })
            })
        })
    }
1045
    /// Queues a prediction refresh in response to a diagnostics update:
    /// locates the next diagnostic starting from the active buffer and
    /// requests a prediction there. Buffer-triggered predictions take
    /// precedence over the result.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active entry, if it can be opened.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find the next diagnostic to jump to; bail if there is none.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction_result) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(
                            &project,
                            &jump_buffer,
                            jump_position,
                            PredictEditsRequestTrigger::Diagnostics,
                            cx,
                        )
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                // If a buffer-triggered prediction arrived in the meantime,
                // keep it and report this one as rejected in its favor.
                this.update(cx, |this, cx| {
                    Some((
                        if this
                            .get_or_init_zeta_project(&project, cx)
                            .current_prediction
                            .is_none()
                        {
                            prediction_result
                        } else {
                            EditPredictionResult {
                                id: prediction_result.id,
                                prediction: Err(EditPredictionRejectReason::CurrentPreferred),
                            }
                        },
                        PredictionRequestedBy::DiagnosticsUpdate,
                    ))
                })
            })
        });
    }
1126
    // Minimum spacing between consecutive prediction refreshes for the same
    // throttle entity. Zero under `cfg(test)` so tests run without waiting.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1131
1132 fn queue_prediction_refresh(
1133 &mut self,
1134 project: Entity<Project>,
1135 throttle_entity: EntityId,
1136 cx: &mut Context<Self>,
1137 do_refresh: impl FnOnce(
1138 WeakEntity<Self>,
1139 &mut AsyncApp,
1140 )
1141 -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
1142 + 'static,
1143 ) {
1144 let zeta_project = self.get_or_init_zeta_project(&project, cx);
1145 let pending_prediction_id = zeta_project.next_pending_prediction_id;
1146 zeta_project.next_pending_prediction_id += 1;
1147 let last_request = zeta_project.last_prediction_refresh;
1148
1149 let task = cx.spawn(async move |this, cx| {
1150 if let Some((last_entity, last_timestamp)) = last_request
1151 && throttle_entity == last_entity
1152 && let Some(timeout) =
1153 (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
1154 {
1155 cx.background_executor().timer(timeout).await;
1156 }
1157
1158 // If this task was cancelled before the throttle timeout expired,
1159 // do not perform a request.
1160 let mut is_cancelled = true;
1161 this.update(cx, |this, cx| {
1162 let project_state = this.get_or_init_zeta_project(&project, cx);
1163 if !project_state
1164 .cancelled_predictions
1165 .remove(&pending_prediction_id)
1166 {
1167 project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
1168 is_cancelled = false;
1169 }
1170 })
1171 .ok();
1172 if is_cancelled {
1173 return None;
1174 }
1175
1176 let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
1177 let new_prediction_id = new_prediction_result
1178 .as_ref()
1179 .map(|(prediction, _)| prediction.id.clone());
1180
1181 // When a prediction completes, remove it from the pending list, and cancel
1182 // any pending predictions that were enqueued before it.
1183 this.update(cx, |this, cx| {
1184 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1185
1186 let is_cancelled = zeta_project
1187 .cancelled_predictions
1188 .remove(&pending_prediction_id);
1189
1190 let new_current_prediction = if !is_cancelled
1191 && let Some((prediction_result, requested_by)) = new_prediction_result
1192 {
1193 match prediction_result.prediction {
1194 Ok(prediction) => {
1195 let new_prediction = CurrentEditPrediction {
1196 requested_by,
1197 prediction,
1198 was_shown: false,
1199 };
1200
1201 if let Some(current_prediction) =
1202 zeta_project.current_prediction.as_ref()
1203 {
1204 if new_prediction.should_replace_prediction(¤t_prediction, cx)
1205 {
1206 this.reject_current_prediction(
1207 EditPredictionRejectReason::Replaced,
1208 &project,
1209 cx,
1210 );
1211
1212 Some(new_prediction)
1213 } else {
1214 this.reject_prediction(
1215 new_prediction.prediction.id,
1216 EditPredictionRejectReason::CurrentPreferred,
1217 false,
1218 cx,
1219 );
1220 None
1221 }
1222 } else {
1223 Some(new_prediction)
1224 }
1225 }
1226 Err(reject_reason) => {
1227 this.reject_prediction(prediction_result.id, reject_reason, false, cx);
1228 None
1229 }
1230 }
1231 } else {
1232 None
1233 };
1234
1235 let zeta_project = this.get_or_init_zeta_project(&project, cx);
1236
1237 if let Some(new_prediction) = new_current_prediction {
1238 zeta_project.current_prediction = Some(new_prediction);
1239 }
1240
1241 let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
1242 for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
1243 if pending_prediction.id == pending_prediction_id {
1244 pending_predictions.remove(ix);
1245 for pending_prediction in pending_predictions.drain(0..ix) {
1246 zeta_project.cancel_pending_prediction(pending_prediction, cx)
1247 }
1248 break;
1249 }
1250 }
1251 this.get_or_init_zeta_project(&project, cx)
1252 .pending_predictions = pending_predictions;
1253 cx.notify();
1254 })
1255 .ok();
1256
1257 new_prediction_id
1258 });
1259
1260 if zeta_project.pending_predictions.len() <= 1 {
1261 zeta_project.pending_predictions.push(PendingPrediction {
1262 id: pending_prediction_id,
1263 task,
1264 });
1265 } else if zeta_project.pending_predictions.len() == 2 {
1266 let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
1267 zeta_project.pending_predictions.push(PendingPrediction {
1268 id: pending_prediction_id,
1269 task,
1270 });
1271 zeta_project.cancel_pending_prediction(pending_prediction, cx);
1272 }
1273 }
1274
1275 pub fn request_prediction(
1276 &mut self,
1277 project: &Entity<Project>,
1278 active_buffer: &Entity<Buffer>,
1279 position: language::Anchor,
1280 trigger: PredictEditsRequestTrigger,
1281 cx: &mut Context<Self>,
1282 ) -> Task<Result<Option<EditPredictionResult>>> {
1283 self.request_prediction_internal(
1284 project.clone(),
1285 active_buffer.clone(),
1286 position,
1287 trigger,
1288 cx.has_flag::<Zeta2FeatureFlag>(),
1289 cx,
1290 )
1291 }
1292
    /// Dispatches a prediction request to the configured backend model
    /// (Zeta1, Zeta2, or Sweep).
    ///
    /// When the backend yields no prediction, `allow_jump` is set, and there
    /// are recent edit events, retries once at the nearest diagnostic outside
    /// the already-searched range — passing `allow_jump = false` so the retry
    /// cannot recurse further.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        trigger: PredictEditsRequestTrigger,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        // Diagnostics within this many rows of the cursor are considered
        // already covered by the current request.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                trigger,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            let prediction = task.await?;

            // Empty result: optionally jump to the closest diagnostic that
            // wasn't near enough to the cursor to be covered already.
            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                trigger,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1384
    /// Finds the next diagnostic worth jumping to: first the closest
    /// diagnostic in the active buffer outside the given search range, and
    /// failing that, a diagnostic in the project buffer whose path shares the
    /// most parent directories with the active buffer's path.
    ///
    /// Returns `None` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by row distance from the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        // Fall back to other buffers with diagnostics in the project.
        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    // Skip the active buffer; it was handled above.
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // NOTE(review): assumes lower severity values sort
                            // first (i.e. `min` picks the most severe entry) —
                            // confirm against the severity type's ordering.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1463
    /// Requests an edit prediction from the Zeta2 backend.
    ///
    /// Gathers context (agentic or syntax-based), builds a prompt, sends it
    /// to the LLM endpoint, and parses the model's output (a diff or XML
    /// edits, depending on the prompt format) into an `EditPredictionResult`.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        trigger: PredictEditsRequestTrigger,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Snapshot of the syntax index (if this project has one), used for
        // excerpt selection and context gathering on the background thread.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let file = active_buffer.read(cx).file();
        // Parent directory of the active file, if any.
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Collect (buffer, snapshot, path, ranges) for each context file
        // that still has an associated file.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering: by path, then by range count.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                // Build the wire request; its shape depends on the context mode.
                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        // Select the excerpt around the cursor; an empty result
                        // means no prediction can be made.
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // existing ranges (moving that buffer to the end of the
                        // list), or append the active buffer if it wasn't a
                        // context file.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges into line-based excerpts for
                        // the wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            // Excerpt content travels via `included_files` in
                            // agentic mode; these fields are unused.
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                            trigger,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                            trigger,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // If the debug view is listening, report the request and hand
                // it a channel through which the response will be forwarded.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Local-debugging escape hatch: skip the network request.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                // Forward the raw response (or error) to the debug view.
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                // A response with no text is a valid "no edits" prediction.
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((Some((request_id, None)), usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolve a path mentioned in the model output back to one of
                // the context buffers we sent.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output according to the prompt format in use.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        // These formats emit a sentinel diff when there is
                        // nothing to change.
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        Some((
                            inputs,
                            edited_buffer,
                            edited_buffer_snapshot.clone(),
                            edits,
                            received_response_at,
                        )),
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, prediction)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // `None` prediction payload means the model returned no text:
                // record the id with an `Empty` rejection.
                let Some((
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = prediction
                else {
                    return Ok(Some(EditPredictionResult {
                        id,
                        prediction: Err(EditPredictionRejectReason::Empty),
                    }));
                };

                // TODO telemetry: duration, etc
                Ok(Some(
                    EditPredictionResult::new(
                        id,
                        &edited_buffer,
                        &edited_buffer_snapshot,
                        edits.into(),
                        buffer_snapshotted_at,
                        received_response_at,
                        inputs,
                        cx,
                    )
                    .await,
                ))
            }
        })
    }
1882
    /// Sends an OpenAI-style request to the edit-prediction endpoint (or the
    /// `PREDICT_EDITS_URL` override), returning the parsed response and any
    /// usage information from the response headers.
    ///
    /// With the `eval-support` feature, responses are cached keyed on a hash
    /// of the URL and serialized request so eval runs are reproducible.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // A cache hit short-circuits the network round trip entirely
        // (usage is `None` in that case).
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Populate the cache on a miss.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1940
    /// Unwraps an API response pair: on success, records any usage info in
    /// the user store and returns the payload; on `ZedUpdateRequiredError`,
    /// sets `update_required` and shows an update notification before
    /// propagating the error unchanged.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    // Best-effort: the entity may already be dropped.
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        // Show a persistent prompt with a link to download the
                        // required update.
                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1984
    /// POSTs a request built by `build`, attaching the auth token and Zed
    /// version headers, and deserializes the JSON response body.
    ///
    /// Retries exactly once with a refreshed token when the server signals an
    /// expired LLM token; enforces the server's minimum-required-version
    /// header by returning a `ZedUpdateRequiredError`.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server can demand a newer client; surface that as a typed
            // error so the caller can prompt for an update.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and loop to resend.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2048
    // An edit after this much inactivity refreshes context immediately; edits
    // within an active session are debounced by the second constant instead.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2051
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context refresh only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" means no tracked edit for at least the idle duration (or no
        // prior edit at all).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task drops (cancels) it, so only
        // the latest edit's task survives the debounce window.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming out of idle: refresh right away.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the task so the refresh runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2099
2100 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
2101 // and avoid spawning more than one concurrent task.
2102 pub fn refresh_context(
2103 &mut self,
2104 project: Entity<Project>,
2105 buffer: Entity<language::Buffer>,
2106 cursor_position: language::Anchor,
2107 cx: &mut Context<Self>,
2108 ) -> Task<Result<()>> {
2109 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2110 return Task::ready(anyhow::Ok(()));
2111 };
2112
2113 let ContextMode::Agentic(options) = &self.options().context else {
2114 return Task::ready(anyhow::Ok(()));
2115 };
2116
2117 let snapshot = buffer.read(cx).snapshot();
2118 let cursor_point = cursor_position.to_point(&snapshot);
2119 let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
2120 cursor_point,
2121 &snapshot,
2122 &options.excerpt,
2123 None,
2124 ) else {
2125 return Task::ready(Ok(()));
2126 };
2127
2128 let app_version = AppVersion::global(cx);
2129 let client = self.client.clone();
2130 let llm_token = self.llm_token.clone();
2131 let debug_tx = self.debug_tx.clone();
2132 let current_file_path: Arc<Path> = snapshot
2133 .file()
2134 .map(|f| f.full_path(cx).into())
2135 .unwrap_or_else(|| Path::new("untitled").into());
2136
2137 let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
2138 predict_edits_v3::PlanContextRetrievalRequest {
2139 excerpt: cursor_excerpt.text(&snapshot).body,
2140 excerpt_path: current_file_path,
2141 excerpt_line_range: cursor_excerpt.line_range,
2142 cursor_file_max_row: Line(snapshot.max_point().row),
2143 events: zeta_project.events(cx),
2144 },
2145 ) {
2146 Ok(prompt) => prompt,
2147 Err(err) => {
2148 return Task::ready(Err(err));
2149 }
2150 };
2151
2152 if let Some(debug_tx) = &debug_tx {
2153 debug_tx
2154 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
2155 ZetaContextRetrievalStartedDebugInfo {
2156 project: project.clone(),
2157 timestamp: Instant::now(),
2158 search_prompt: prompt.clone(),
2159 },
2160 ))
2161 .ok();
2162 }
2163
2164 pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
2165 let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
2166 language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
2167 );
2168
2169 let description = schema
2170 .get("description")
2171 .and_then(|description| description.as_str())
2172 .unwrap()
2173 .to_string();
2174
2175 (schema.into(), description)
2176 });
2177
2178 let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
2179
2180 let request = open_ai::Request {
2181 model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
2182 messages: vec![open_ai::RequestMessage::User {
2183 content: open_ai::MessageContent::Plain(prompt),
2184 }],
2185 stream: false,
2186 max_completion_tokens: None,
2187 stop: Default::default(),
2188 temperature: 0.7,
2189 tool_choice: None,
2190 parallel_tool_calls: None,
2191 tools: vec![open_ai::ToolDefinition::Function {
2192 function: FunctionDefinition {
2193 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
2194 description: Some(tool_description),
2195 parameters: Some(tool_schema),
2196 },
2197 }],
2198 prompt_cache_key: None,
2199 reasoning_effort: None,
2200 };
2201
2202 #[cfg(feature = "eval-support")]
2203 let eval_cache = self.eval_cache.clone();
2204
2205 cx.spawn(async move |this, cx| {
2206 log::trace!("Sending search planning request");
2207 let response = Self::send_raw_llm_request(
2208 request,
2209 client,
2210 llm_token,
2211 app_version,
2212 #[cfg(feature = "eval-support")]
2213 eval_cache.clone(),
2214 #[cfg(feature = "eval-support")]
2215 EvalCacheEntryKind::Context,
2216 )
2217 .await;
2218 let mut response = Self::handle_api_response(&this, response, cx)?;
2219 log::trace!("Got search planning response");
2220
2221 let choice = response
2222 .choices
2223 .pop()
2224 .context("No choices in retrieval response")?;
2225 let open_ai::RequestMessage::Assistant {
2226 content: _,
2227 tool_calls,
2228 } = choice.message
2229 else {
2230 anyhow::bail!("Retrieval response didn't include an assistant message");
2231 };
2232
2233 let mut queries: Vec<SearchToolQuery> = Vec::new();
2234 for tool_call in tool_calls {
2235 let open_ai::ToolCallContent::Function { function } = tool_call.content;
2236 if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
2237 log::warn!(
2238 "Context retrieval response tried to call an unknown tool: {}",
2239 function.name
2240 );
2241
2242 continue;
2243 }
2244
2245 let input: SearchToolInput = serde_json::from_str(&function.arguments)
2246 .with_context(|| format!("invalid search json {}", &function.arguments))?;
2247 queries.extend(input.queries);
2248 }
2249
2250 if let Some(debug_tx) = &debug_tx {
2251 debug_tx
2252 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
2253 ZetaSearchQueryDebugInfo {
2254 project: project.clone(),
2255 timestamp: Instant::now(),
2256 search_queries: queries.clone(),
2257 },
2258 ))
2259 .ok();
2260 }
2261
2262 log::trace!("Running retrieval search: {queries:#?}");
2263
2264 let related_excerpts_result = retrieval_search::run_retrieval_searches(
2265 queries,
2266 project.clone(),
2267 #[cfg(feature = "eval-support")]
2268 eval_cache,
2269 cx,
2270 )
2271 .await;
2272
2273 log::trace!("Search queries executed");
2274
2275 if let Some(debug_tx) = &debug_tx {
2276 debug_tx
2277 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
2278 ZetaContextRetrievalDebugInfo {
2279 project: project.clone(),
2280 timestamp: Instant::now(),
2281 },
2282 ))
2283 .ok();
2284 }
2285
2286 this.update(cx, |this, _cx| {
2287 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
2288 return Ok(());
2289 };
2290 zeta_project.refresh_context_task.take();
2291 if let Some(debug_tx) = &this.debug_tx {
2292 debug_tx
2293 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
2294 ZetaContextRetrievalDebugInfo {
2295 project,
2296 timestamp: Instant::now(),
2297 },
2298 ))
2299 .ok();
2300 }
2301 match related_excerpts_result {
2302 Ok(excerpts) => {
2303 zeta_project.context = Some(excerpts);
2304 Ok(())
2305 }
2306 Err(error) => Err(error),
2307 }
2308 })?
2309 })
2310 }
2311
2312 pub fn set_context(
2313 &mut self,
2314 project: Entity<Project>,
2315 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2316 ) {
2317 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2318 zeta_project.context = Some(context);
2319 }
2320 }
2321
2322 fn gather_nearby_diagnostics(
2323 cursor_offset: usize,
2324 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2325 snapshot: &BufferSnapshot,
2326 max_diagnostics_bytes: usize,
2327 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2328 // TODO: Could make this more efficient
2329 let mut diagnostic_groups = Vec::new();
2330 for (language_server_id, diagnostics) in diagnostic_sets {
2331 let mut groups = Vec::new();
2332 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2333 diagnostic_groups.extend(
2334 groups
2335 .into_iter()
2336 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2337 );
2338 }
2339
2340 // sort by proximity to cursor
2341 diagnostic_groups.sort_by_key(|group| {
2342 let range = &group.entries[group.primary_ix].range;
2343 if range.start >= cursor_offset {
2344 range.start - cursor_offset
2345 } else if cursor_offset >= range.end {
2346 cursor_offset - range.end
2347 } else {
2348 (cursor_offset - range.start).min(range.end - cursor_offset)
2349 }
2350 });
2351
2352 let mut results = Vec::new();
2353 let mut diagnostic_groups_truncated = false;
2354 let mut diagnostics_byte_count = 0;
2355 for group in diagnostic_groups {
2356 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2357 diagnostics_byte_count += raw_value.get().len();
2358 if diagnostics_byte_count > max_diagnostics_bytes {
2359 diagnostic_groups_truncated = true;
2360 break;
2361 }
2362 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2363 }
2364
2365 (results, diagnostic_groups_truncated)
2366 }
2367
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI without sending
    /// it, using the same syntax-context gathering as regular predictions.
    ///
    /// Returns an error if the buffer has no file path. Panics if the
    /// configured context mode is `Agentic` (not supported here yet — see TODO
    /// in the body).
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle (if any) so the background task
        // can lock it without touching the entity system.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the file is
        // backed by a worktree entry.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                    PredictEditsRequestTrigger::Other,
                )
            })
        })
    }
2446
2447 pub fn wait_for_initial_indexing(
2448 &mut self,
2449 project: &Entity<Project>,
2450 cx: &mut Context<Self>,
2451 ) -> Task<Result<()>> {
2452 let zeta_project = self.get_or_init_zeta_project(project, cx);
2453 if let Some(syntax_index) = &zeta_project.syntax_index {
2454 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2455 } else {
2456 Task::ready(Ok(()))
2457 }
2458 }
2459
2460 fn is_file_open_source(
2461 &self,
2462 project: &Entity<Project>,
2463 file: &Arc<dyn File>,
2464 cx: &App,
2465 ) -> bool {
2466 if !file.is_local() || file.is_private() {
2467 return false;
2468 }
2469 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2470 return false;
2471 };
2472 zeta_project
2473 .license_detection_watchers
2474 .get(&file.worktree_id(cx))
2475 .as_ref()
2476 .is_some_and(|watcher| watcher.is_project_open_source())
2477 }
2478
2479 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2480 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2481 }
2482
2483 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2484 if !self.data_collection_choice.is_enabled() {
2485 return false;
2486 }
2487 events.iter().all(|event| {
2488 matches!(
2489 event.as_ref(),
2490 Event::BufferChange {
2491 in_open_source_repo: true,
2492 ..
2493 }
2494 )
2495 })
2496 }
2497
2498 fn load_data_collection_choice() -> DataCollectionChoice {
2499 let choice = KEY_VALUE_STORE
2500 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2501 .log_err()
2502 .flatten();
2503
2504 match choice.as_deref() {
2505 Some("true") => DataCollectionChoice::Enabled,
2506 Some("false") => DataCollectionChoice::Disabled,
2507 Some(_) => {
2508 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2509 DataCollectionChoice::NotAnswered
2510 }
2511 None => DataCollectionChoice::NotAnswered,
2512 }
2513 }
2514
    /// Iterates over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2518
    /// Number of predictions that have been shown to the user.
    // NOTE(review): named "completions" but counts `shown_predictions`;
    // consider renaming for consistency (would require updating callers).
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2522
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2526
    /// Records the user's rating for `prediction`, reports it to telemetry,
    /// and flushes the telemetry queue immediately so the feedback isn't lost.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        // Remember the rating so duplicate submissions can be prevented
        // (see `is_prediction_rated`).
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2545}
2546
2547pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2548 let choice = res.choices.pop()?;
2549 let output_text = match choice.message {
2550 open_ai::RequestMessage::Assistant {
2551 content: Some(open_ai::MessageContent::Plain(content)),
2552 ..
2553 } => content,
2554 open_ai::RequestMessage::Assistant {
2555 content: Some(open_ai::MessageContent::Multipart(mut content)),
2556 ..
2557 } => {
2558 if content.is_empty() {
2559 log::error!("No output from Baseten completion response");
2560 return None;
2561 }
2562
2563 match content.remove(0) {
2564 open_ai::MessagePart::Text { text } => text,
2565 open_ai::MessagePart::Image { .. } => {
2566 log::error!("Expected text, got an image");
2567 return None;
2568 }
2569 }
2570 }
2571 _ => {
2572 log::error!("Invalid response message: {:?}", choice.message);
2573 return None;
2574 }
2575 };
2576 Some(output_text)
2577}
2578
/// Error returned when the server requires a newer Zed version than the one
/// currently running (see the expired/minimum-version response headers).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2586
2587fn make_syntax_context_cloud_request(
2588 excerpt_path: Arc<Path>,
2589 context: EditPredictionContext,
2590 events: Vec<Arc<predict_edits_v3::Event>>,
2591 can_collect_data: bool,
2592 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2593 diagnostic_groups_truncated: bool,
2594 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2595 debug_info: bool,
2596 worktrees: &Vec<worktree::Snapshot>,
2597 index_state: Option<&SyntaxIndexState>,
2598 prompt_max_bytes: Option<usize>,
2599 prompt_format: PromptFormat,
2600 trigger: PredictEditsRequestTrigger,
2601) -> predict_edits_v3::PredictEditsRequest {
2602 let mut signatures = Vec::new();
2603 let mut declaration_to_signature_index = HashMap::default();
2604 let mut referenced_declarations = Vec::new();
2605
2606 for snippet in context.declarations {
2607 let project_entry_id = snippet.declaration.project_entry_id();
2608 let Some(path) = worktrees.iter().find_map(|worktree| {
2609 worktree.entry_for_id(project_entry_id).map(|entry| {
2610 let mut full_path = RelPathBuf::new();
2611 full_path.push(worktree.root_name());
2612 full_path.push(&entry.path);
2613 full_path
2614 })
2615 }) else {
2616 continue;
2617 };
2618
2619 let parent_index = index_state.and_then(|index_state| {
2620 snippet.declaration.parent().and_then(|parent| {
2621 add_signature(
2622 parent,
2623 &mut declaration_to_signature_index,
2624 &mut signatures,
2625 index_state,
2626 )
2627 })
2628 });
2629
2630 let (text, text_is_truncated) = snippet.declaration.item_text();
2631 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2632 path: path.as_std_path().into(),
2633 text: text.into(),
2634 range: snippet.declaration.item_line_range(),
2635 text_is_truncated,
2636 signature_range: snippet.declaration.signature_range_in_item_text(),
2637 parent_index,
2638 signature_score: snippet.score(DeclarationStyle::Signature),
2639 declaration_score: snippet.score(DeclarationStyle::Declaration),
2640 score_components: snippet.components,
2641 });
2642 }
2643
2644 let excerpt_parent = index_state.and_then(|index_state| {
2645 context
2646 .excerpt
2647 .parent_declarations
2648 .last()
2649 .and_then(|(parent, _)| {
2650 add_signature(
2651 *parent,
2652 &mut declaration_to_signature_index,
2653 &mut signatures,
2654 index_state,
2655 )
2656 })
2657 });
2658
2659 predict_edits_v3::PredictEditsRequest {
2660 excerpt_path,
2661 excerpt: context.excerpt_text.body,
2662 excerpt_line_range: context.excerpt.line_range,
2663 excerpt_range: context.excerpt.range,
2664 cursor_point: predict_edits_v3::Point {
2665 line: predict_edits_v3::Line(context.cursor_point.row),
2666 column: context.cursor_point.column,
2667 },
2668 referenced_declarations,
2669 included_files: vec![],
2670 signatures,
2671 excerpt_parent,
2672 events,
2673 can_collect_data,
2674 diagnostic_groups,
2675 diagnostic_groups_truncated,
2676 git_info,
2677 debug_info,
2678 prompt_max_bytes,
2679 prompt_format,
2680 trigger,
2681 }
2682}
2683
2684fn add_signature(
2685 declaration_id: DeclarationId,
2686 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2687 signatures: &mut Vec<Signature>,
2688 index: &SyntaxIndexState,
2689) -> Option<usize> {
2690 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2691 return Some(*signature_index);
2692 }
2693 let Some(parent_declaration) = index.declaration(declaration_id) else {
2694 log::error!("bug: missing parent declaration");
2695 return None;
2696 };
2697 let parent_index = parent_declaration.parent().and_then(|parent| {
2698 add_signature(parent, declaration_to_signature_index, signatures, index)
2699 });
2700 let (text, text_is_truncated) = parent_declaration.signature_text();
2701 let signature_index = signatures.len();
2702 signatures.push(Signature {
2703 text: text.into(),
2704 text_is_truncated,
2705 parent_index,
2706 range: parent_declaration.signature_line_range(),
2707 });
2708 declaration_to_signature_index.insert(declaration_id, signature_index);
2709 Some(signature_index)
2710}
2711
/// Cache key for eval-support LLM caching: the kind of interaction paired
/// with a hash identifying the request.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2714
/// The kind of LLM interaction an eval cache entry corresponds to.
//
// `Eq` and `Hash` are derived in addition to `PartialEq` so `EvalCacheKey`
// can be used directly as a hash-map key (clippy also flags deriving
// `PartialEq` without `Eq` on types that support it).
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    /// Context-retrieval planning request (see `EvalCacheEntryKind::Context`
    /// usage when sending the search-planning request).
    Context,
    /// Retrieval search execution.
    Search,
    /// The edit prediction request itself.
    Prediction,
}
2722
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase label used for cache entry names.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2733
/// Key-value store used to cache LLM responses across eval runs.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if one exists.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` — the response produced for `input` — under `key`.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2739
/// The user's choice about sharing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    // Not answered yet; treated as disabled by `is_enabled`.
    NotAnswered,
    Enabled,
    Disabled,
}
2746
2747impl DataCollectionChoice {
2748 pub fn is_enabled(self) -> bool {
2749 match self {
2750 Self::Enabled => true,
2751 Self::NotAnswered | Self::Disabled => false,
2752 }
2753 }
2754
2755 pub fn is_answered(self) -> bool {
2756 match self {
2757 Self::Enabled | Self::Disabled => true,
2758 Self::NotAnswered => false,
2759 }
2760 }
2761
2762 #[must_use]
2763 pub fn toggle(&self) -> DataCollectionChoice {
2764 match self {
2765 Self::Enabled => Self::Disabled,
2766 Self::Disabled => Self::Enabled,
2767 Self::NotAnswered => Self::Enabled,
2768 }
2769 }
2770}
2771
2772impl From<bool> for DataCollectionChoice {
2773 fn from(value: bool) -> Self {
2774 match value {
2775 true => DataCollectionChoice::Enabled,
2776 false => DataCollectionChoice::Disabled,
2777 }
2778 }
2779}
2780
/// Marker type for the dismissable edit-prediction upsell (see `Dismissable` impl below).
struct ZedPredictUpsell;
2782
2783impl Dismissable for ZedPredictUpsell {
2784 const KEY: &'static str = "dismissed-edit-predict-upsell";
2785
2786 fn dismissed() -> bool {
2787 // To make this backwards compatible with older versions of Zed, we
2788 // check if the user has seen the previous Edit Prediction Onboarding
2789 // before, by checking the data collection choice which was written to
2790 // the database once the user clicked on "Accept and Enable"
2791 if KEY_VALUE_STORE
2792 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2793 .log_err()
2794 .is_some_and(|s| s.is_some())
2795 {
2796 return true;
2797 }
2798
2799 KEY_VALUE_STORE
2800 .read_kvp(Self::KEY)
2801 .log_err()
2802 .is_some_and(|s| s.is_some())
2803 }
2804}
2805
/// Whether the edit-prediction upsell modal should be shown, i.e. the user
/// has not dismissed it (or completed the older onboarding flow).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2809
/// Registers edit-prediction workspace actions and sets up command-palette
/// feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind the rate-completions flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): "reset onboarding" only sets the provider setting to
        // `None`; confirm that's sufficient to re-trigger onboarding.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2845
/// Hides edit-prediction actions from the command palette by default and
/// keeps their visibility in sync with the "disable AI" setting and the
/// rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Everything starts hidden; the observers below only ever re-show the
    // rate-completions actions.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // Settings changes: disabling AI hides everything; otherwise the
    // rate-completions visibility follows the feature flag.
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Feature-flag changes take effect immediately unless AI is disabled.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2897
2898#[cfg(test)]
2899mod tests {
2900 use std::{path::Path, sync::Arc};
2901
2902 use client::UserStore;
2903 use clock::FakeSystemClock;
2904 use cloud_llm_client::{
2905 EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
2906 };
2907 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2908 use futures::{
2909 AsyncReadExt, StreamExt,
2910 channel::{mpsc, oneshot},
2911 };
2912 use gpui::{
2913 Entity, TestAppContext,
2914 http_client::{FakeHttpClient, Response},
2915 prelude::*,
2916 };
2917 use indoc::indoc;
2918 use language::OffsetRangeExt as _;
2919 use open_ai::Usage;
2920 use pretty_assertions::{assert_eq, assert_matches};
2921 use project::{FakeFs, Project};
2922 use serde_json::json;
2923 use settings::SettingsStore;
2924 use util::path;
2925 use uuid::Uuid;
2926
2927 use crate::{BufferEditPrediction, Zeta};
2928
    // Exercises `current_prediction_for_buffer` end to end: a same-file
    // prediction, a context refresh driven by a search tool call, and a
    // cross-file prediction that surfaces as `Jump` until its buffer is open.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        // Respond with a tool call that plans a search over 2.txt.
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = requests.predict.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // 2.txt isn't open, so from buffer1's perspective this is a `Jump`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once the target buffer is open, the same prediction is `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3075
    // Single prediction round-trip: request, canned diff response, and the
    // resulting edit's anchor position and inserted text.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        // One edit: insert " are you?" at row 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3139
    // Buffer edits made before a request are forwarded as events and appear
    // in the prompt as a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the registered buffer so a change event is recorded.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
        });

        let (request, respond_tx) = requests.predict.next().await.unwrap();

        // The edit above must show up in the prompt as a diff event.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3213
    // A diff that changes nothing yields no current prediction and is
    // reported to the server as rejected with reason `Empty`.
    #[gpui::test]
    async fn test_empty_prediction(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        // Before/after lines are identical, so the prediction is empty.
        const NO_OP_DIFF: &str = indoc! { r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How
            Bye
        "};

        let (_, respond_tx) = requests.predict.next().await.unwrap();
        let response = model_response(NO_OP_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::Empty,
                was_shown: false
            }]
        );
    }
3277
    // If the user types the predicted text while the request is in flight,
    // interpolation leaves nothing to apply and the prediction is rejected
    // with reason `InterpolatedEmpty`.
    #[gpui::test]
    async fn test_interpolated_empty(cx: &mut TestAppContext) {
        let (zeta, mut requests) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        });

        let (_, respond_tx) = requests.predict.next().await.unwrap();

        // Apply the very edit the model is about to suggest, before it responds.
        buffer.update(cx, |buffer, cx| {
            buffer.set_text("Hello!\nHow are you?\nBye", cx);
        });

        let response = model_response(SIMPLE_DIFF);
        let id = response.id.clone();
        respond_tx.send(response).unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            assert!(
                zeta.current_prediction_for_buffer(&buffer, &project, cx)
                    .is_none()
            );
        });

        // prediction is reported as rejected
        let (reject_request, _) = requests.reject.next().await.unwrap();

        assert_eq!(
            &reject_request.rejections,
            &[EditPredictionRejection {
                request_id: id,
                reason: EditPredictionRejectReason::InterpolatedEmpty,
                was_shown: false
            }]
        );
    }
3336
/// Canned unified diff rewriting line 2 of `foo.md` ("How" -> "How are you?"),
/// used as the stock model response throughout the tests in this module.
const SIMPLE_DIFF: &str = indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"};
3346
3347 #[gpui::test]
3348 async fn test_replace_current(cx: &mut TestAppContext) {
3349 let (zeta, mut requests) = init_test(cx);
3350 let fs = FakeFs::new(cx.executor());
3351 fs.insert_tree(
3352 "/root",
3353 json!({
3354 "foo.md": "Hello!\nHow\nBye\n"
3355 }),
3356 )
3357 .await;
3358 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3359
3360 let buffer = project
3361 .update(cx, |project, cx| {
3362 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3363 project.open_buffer(path, cx)
3364 })
3365 .await
3366 .unwrap();
3367 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3368 let position = snapshot.anchor_before(language::Point::new(1, 3));
3369
3370 zeta.update(cx, |zeta, cx| {
3371 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3372 });
3373
3374 let (_, respond_tx) = requests.predict.next().await.unwrap();
3375 let first_response = model_response(SIMPLE_DIFF);
3376 let first_id = first_response.id.clone();
3377 respond_tx.send(first_response).unwrap();
3378
3379 cx.run_until_parked();
3380
3381 zeta.read_with(cx, |zeta, cx| {
3382 assert_eq!(
3383 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3384 .unwrap()
3385 .id
3386 .0,
3387 first_id
3388 );
3389 });
3390
3391 // a second request is triggered
3392 zeta.update(cx, |zeta, cx| {
3393 zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
3394 });
3395
3396 let (_, respond_tx) = requests.predict.next().await.unwrap();
3397 let second_response = model_response(SIMPLE_DIFF);
3398 let second_id = second_response.id.clone();
3399 respond_tx.send(second_response).unwrap();
3400
3401 cx.run_until_parked();
3402
3403 zeta.read_with(cx, |zeta, cx| {
3404 // second replaces first
3405 assert_eq!(
3406 zeta.current_prediction_for_buffer(&buffer, &project, cx)
3407 .unwrap()
3408 .id
3409 .0,
3410 second_id
3411 );
3412 });
3413
3414 // first is reported as replaced
3415 let (reject_request, _) = requests.reject.next().await.unwrap();
3416
3417 assert_eq!(
3418 &reject_request.rejections,
3419 &[EditPredictionRejection {
3420 request_id: first_id,
3421 reason: EditPredictionRejectReason::Replaced,
3422 was_shown: false
3423 }]
3424 );
3425 }
3426
// When a newer prediction's edit is smaller than the current one's ("How are"
// vs. the current "How are you?"), the current prediction is kept and the new
// one is reported as rejected with `CurrentPreferred`.
#[gpui::test]
async fn test_current_preferred(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // Complete the first request; its prediction becomes current.
    let (_, respond_tx) = requests.predict.next().await.unwrap();
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_tx.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // a second request is triggered
    zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_tx) = requests.predict.next().await.unwrap();
    // worse than current prediction: only edits "How" to "How are", a prefix
    // of the current prediction's "How are you?"
    let second_response = model_response(indoc! { r"
        --- a/root/foo.md
        +++ b/root/foo.md
        @@ ... @@
        Hello!
        -How
        +How are
        Bye
    "});
    let second_id = second_response.id.clone();
    respond_tx.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // first is preferred over second
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // second is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: second_id,
            reason: EditPredictionRejectReason::CurrentPreferred,
            was_shown: false
        }]
    );
}
3515
// Two refresh requests are in flight; the later one completes first and
// becomes the current prediction. The earlier request is treated as
// cancelled: its late response must not displace the current prediction,
// and it is reported back with `Canceled`.
#[gpui::test]
async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        // start two refresh tasks
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    let (_, respond_first) = requests.predict.next().await.unwrap();
    let (_, respond_second) = requests.predict.next().await.unwrap();

    // wait for throttle
    cx.run_until_parked();

    // second responds first
    let second_response = model_response(SIMPLE_DIFF);
    let second_id = second_response.id.clone();
    respond_second.send(second_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is second
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // Now let the earlier (cancelled) request respond.
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still second, since first was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            second_id
        );
    });

    // first is reported as rejected
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[EditPredictionRejection {
            request_id: first_id,
            reason: EditPredictionRejectReason::Canceled,
            was_shown: false
        }]
    );
}
3601
// With two requests already pending, starting a third cancels the second
// (index 1 shows up in `cancelled_predictions`). The first request then
// completes and becomes current; the cancelled second's response is
// discarded; the third completes and replaces the first. The reject report
// must contain the second as `Canceled` and the first as `Replaced`.
#[gpui::test]
async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
    let (zeta, mut requests) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    zeta.update(cx, |zeta, cx| {
        // start two refresh tasks
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
    });

    // wait for throttle, so requests are sent
    cx.run_until_parked();

    let (_, respond_first) = requests.predict.next().await.unwrap();
    let (_, respond_second) = requests.predict.next().await.unwrap();

    zeta.update(cx, |zeta, cx| {
        // start a third request
        zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);

        // 2 are pending, so 2nd is cancelled
        assert_eq!(
            zeta.get_or_init_zeta_project(&project, cx)
                .cancelled_predictions
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            [1]
        );
    });

    // wait for throttle
    cx.run_until_parked();

    let (_, respond_third) = requests.predict.next().await.unwrap();

    // The first (non-cancelled) request completes.
    let first_response = model_response(SIMPLE_DIFF);
    let first_id = first_response.id.clone();
    respond_first.send(first_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    // The cancelled second request responds; its prediction must be ignored.
    let cancelled_response = model_response(SIMPLE_DIFF);
    let cancelled_id = cancelled_response.id.clone();
    respond_second.send(cancelled_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // current prediction is still first, since second was cancelled
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            first_id
        );
    });

    let third_response = model_response(SIMPLE_DIFF);
    let third_response_id = third_response.id.clone();
    respond_third.send(third_response).unwrap();

    cx.run_until_parked();

    zeta.read_with(cx, |zeta, cx| {
        // third completes and replaces first
        assert_eq!(
            zeta.current_prediction_for_buffer(&buffer, &project, cx)
                .unwrap()
                .id
                .0,
            third_response_id
        );
    });

    // both rejections are reported: second as cancelled, first as replaced
    let (reject_request, _) = requests.reject.next().await.unwrap();

    cx.run_until_parked();

    assert_eq!(
        &reject_request.rejections,
        &[
            EditPredictionRejection {
                request_id: cancelled_id,
                reason: EditPredictionRejectReason::Canceled,
                was_shown: false
            },
            EditPredictionRejection {
                request_id: first_id,
                reason: EditPredictionRejectReason::Replaced,
                was_shown: false
            }
        ]
    );
}
3729
3730 // Skipped until we start including diagnostics in prompt
3731 // #[gpui::test]
3732 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3733 // let (zeta, mut req_rx) = init_test(cx);
3734 // let fs = FakeFs::new(cx.executor());
3735 // fs.insert_tree(
3736 // "/root",
3737 // json!({
3738 // "foo.md": "Hello!\nBye"
3739 // }),
3740 // )
3741 // .await;
3742 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3743
3744 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3745 // let diagnostic = lsp::Diagnostic {
3746 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3747 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3748 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3749 // ..Default::default()
3750 // };
3751
3752 // project.update(cx, |project, cx| {
3753 // project.lsp_store().update(cx, |lsp_store, cx| {
3754 // // Create some diagnostics
3755 // lsp_store
3756 // .update_diagnostics(
3757 // LanguageServerId(0),
3758 // lsp::PublishDiagnosticsParams {
3759 // uri: path_to_buffer_uri.clone(),
3760 // diagnostics: vec![diagnostic],
3761 // version: None,
3762 // },
3763 // None,
3764 // language::DiagnosticSourceKind::Pushed,
3765 // &[],
3766 // cx,
3767 // )
3768 // .unwrap();
3769 // });
3770 // });
3771
3772 // let buffer = project
3773 // .update(cx, |project, cx| {
3774 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3775 // project.open_buffer(path, cx)
3776 // })
3777 // .await
3778 // .unwrap();
3779
3780 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3781 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3782
3783 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3784 // zeta.request_prediction(&project, &buffer, position, cx)
3785 // });
3786
3787 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3788
3789 // assert_eq!(request.diagnostic_groups.len(), 1);
3790 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3791 // .unwrap();
3792 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3793 // assert_eq!(
3794 // value,
3795 // json!({
3796 // "entries": [{
3797 // "range": {
3798 // "start": 8,
3799 // "end": 10
3800 // },
3801 // "diagnostic": {
3802 // "source": null,
3803 // "code": null,
3804 // "code_description": null,
3805 // "severity": 1,
3806 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3807 // "markdown": null,
3808 // "group_id": 0,
3809 // "is_primary": true,
3810 // "is_disk_based": false,
3811 // "is_unnecessary": false,
3812 // "source_kind": "Pushed",
3813 // "data": null,
3814 // "underline": true
3815 // }
3816 // }],
3817 // "primary_ix": 0
3818 // })
3819 // );
3820 // }
3821
3822 fn model_response(text: &str) -> open_ai::Response {
3823 open_ai::Response {
3824 id: Uuid::new_v4().to_string(),
3825 object: "response".into(),
3826 created: 0,
3827 model: "model".into(),
3828 choices: vec![open_ai::Choice {
3829 index: 0,
3830 message: open_ai::RequestMessage::Assistant {
3831 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3832 tool_calls: vec![],
3833 },
3834 finish_reason: None,
3835 }],
3836 usage: Usage {
3837 prompt_tokens: 0,
3838 completion_tokens: 0,
3839 total_tokens: 0,
3840 },
3841 }
3842 }
3843
3844 fn prompt_from_request(request: &open_ai::Request) -> &str {
3845 assert_eq!(request.messages.len(), 1);
3846 let open_ai::RequestMessage::User {
3847 content: open_ai::MessageContent::Plain(content),
3848 ..
3849 } = &request.messages[0]
3850 else {
3851 panic!(
3852 "Request does not have single user message of type Plain. {:#?}",
3853 request
3854 );
3855 };
3856 content
3857 }
3858
/// Receivers for requests captured by the fake HTTP client set up in
/// `init_test`; each captured request is paired with a oneshot sender the
/// test uses to complete the corresponding HTTP response.
struct RequestChannels {
    // Requests to `/predict_edits/raw`, with a responder for the model reply.
    predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    // Requests to `/predict_edits/reject`, with a responder to acknowledge.
    reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
}
3863
/// Sets up a `Zeta` entity backed by a fake HTTP client, returning it along
/// with channels that surface every captured predict/reject request together
/// with a responder for completing it.
fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
        let (reject_req_tx, reject_req_rx) = mpsc::unbounded();

        // Fake server: forwards each request body over the matching channel
        // and blocks the HTTP response on the test's oneshot reply.
        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let predict_req_tx = predict_req_tx.clone();
                let reject_req_tx = reject_req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Token refresh always succeeds with a dummy token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            // Hand the request to the test and wait for the
                            // canned model response before replying.
                            let (res_tx, res_rx) = oneshot::channel();
                            predict_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        "/predict_edits/reject" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            reject_req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = Zeta::global(&client, &user_store, cx);

        (
            zeta,
            RequestChannels {
                predict: predict_req_rx,
                reject: reject_req_rx,
            },
        )
    })
}
3930}