1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
7 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
8 RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
9};
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
12use collections::{HashMap, HashSet};
13use command_palette_hooks::CommandPaletteFilter;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{
16 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
17 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
18 SyntaxIndex, SyntaxIndexState,
19};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
21use futures::channel::{mpsc, oneshot};
22use futures::{AsyncReadExt as _, StreamExt as _};
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::{
29 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
30};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, RefreshLlmTokenListener};
33use open_ai::FunctionDefinition;
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
39use std::any::{Any as _, TypeId};
40use std::collections::{VecDeque, hash_map};
41use telemetry_events::EditPredictionRating;
42use workspace::Workspace;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant};
50use std::{env, mem};
51use thiserror::Error;
52use util::rel_path::RelPathBuf;
53use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
55
56pub mod assemble_excerpts;
57mod license_detection;
58mod onboarding_modal;
59mod prediction;
60mod provider;
61mod rate_prediction_modal;
62pub mod retrieval_search;
63mod sweep_ai;
64pub mod udiff;
65mod xml_edits;
66pub mod zeta1;
67
68#[cfg(test)]
69mod zeta_tests;
70
71use crate::assemble_excerpts::assemble_excerpts;
72use crate::license_detection::LicenseDetectionWatcher;
73use crate::onboarding_modal::ZedPredictModal;
74pub use crate::prediction::EditPrediction;
75pub use crate::prediction::EditPredictionId;
76pub use crate::prediction::EditPredictionInputs;
77use crate::rate_prediction_modal::{
78 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
79 ThumbsUpActivePrediction,
80};
81use crate::sweep_ai::SweepAi;
82use crate::zeta1::request_prediction_with_zeta1;
83pub use provider::ZetaEditPredictionProvider;
84
// Actions registered under the `edit_prediction` namespace.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
96
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits within this many rows of the previous edit are coalesced into the
/// same buffer-change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value-store key for the user's persisted data collection choice.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
101
102pub struct SweepFeatureFlag;
103
104impl FeatureFlag for SweepFeatureFlag {
105 const NAME: &str = "sweep-ai";
106}
/// Default sizing of the excerpt taken around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for half of the excerpt budget before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering mode: agentic retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for [`ContextMode::Syntax`] (syntax-index-based context).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used by a freshly-created [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
138
/// `true` when `ZED_ZETA2_OLLAMA` is set to a non-empty value; routes model
/// requests to a local Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
141static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
142 env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
143 "qwen3-coder:30b".to_string()
144 } else {
145 "yqvev8r3".to_string()
146 })
147});
/// Model id used for edit-prediction requests. Overridable via
/// `ZED_ZETA2_MODEL`; otherwise picks the Ollama or hosted default.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional override for the predict-edits endpoint: `ZED_PREDICT_EDITS_URL`
/// if set, else the local Ollama endpoint when `USE_OLLAMA` is on.
/// `None` presumably falls back to the Zed LLM service — confirm at use site.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
166
/// Feature flag gating zeta2-specific behavior; always enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}

/// GPUI global holding the app-wide shared [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
181
/// Coordinates edit predictions: per-project tracking state, model selection,
/// and bookkeeping for shown/rejected/rated predictions. A single instance is
/// shared app-wide via [`ZetaGlobal`].
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Auth token for LLM requests; refreshed by `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by project entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server requires a newer client
    // (cf. MINIMUM_REQUIRED_VERSION_HEADER_NAME import) — set outside this chunk.
    update_required: bool,
    // When present, debug events are mirrored to this channel (`debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    // Which backend serves predictions (Zeta1 / Zeta2 / Sweep).
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    // Rejections accumulated locally until flushed by `reject_edit_predictions`.
    rejected_predictions: Vec<EditPredictionRejection>,
    // Signals the background worker (spawned in `new`) to flush rejections.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    // Most recently shown predictions, newest at the front (capped at 50).
    shown_predictions: VecDeque<EditPrediction>,
    // Ids of shown predictions that have been rated.
    rated_predictions: HashSet<EditPredictionId>,
}
202
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}

/// Tuning knobs for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    /// Upper bound on prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostics included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism for the syntax index's file indexing.
    pub file_indexing_parallelism: usize,
    // NOTE(review): consumer of this interval is outside this chunk — confirm
    // it drives time-based grouping of buffer-change events.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context gathered via model-driven retrieval (no syntax index).
    Agentic(AgenticContextOptions),
    /// Context gathered from the project's [`SyntaxIndex`].
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
231
232impl ContextMode {
233 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
234 match self {
235 ContextMode::Agentic(options) => &options.excerpt,
236 ContextMode::Syntax(options) => &options.excerpt,
237 }
238 }
239}
240
/// Debug events emitted over the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt sent to the search model.
    pub search_prompt: String,
}

/// Common payload for context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// Time spent gathering context before the request.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally-assembled prompt, or an error message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
281
/// Per-project prediction state, owned by [`Zeta::projects`].
struct ZetaProject {
    /// Populated only in [`ContextMode::Syntax`].
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, oldest first; bounded so that together
    /// with the pending `last_event` at most `EVENT_COUNT_MAX` are tracked.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// Most recent change, kept un-finalized so nearby follow-up edits can be
    /// coalesced into it (see `report_changes_for_buffer`).
    last_event: Option<LastEvent>,
    /// Recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two at a time).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Entity and instant of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    cancelled_predictions: HashSet<usize>,
    /// Retrieved context ranges per buffer, when available.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree; used to tag events as open source.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
300
301impl ZetaProject {
302 pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
303 self.events
304 .iter()
305 .cloned()
306 .chain(
307 self.last_event
308 .as_ref()
309 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
310 )
311 .collect()
312 }
313}
314
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed to the user yet.
    pub was_shown: bool,
}
321
impl CurrentEditPrediction {
    /// Decides whether `self` (a newly-arrived prediction) should replace
    /// `old_prediction` as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer interpolates against its buffer's
        // current contents, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // A prediction for a different buffer always replaces.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // A stale old prediction (no longer interpolates) is always replaced.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Single-edit case: only replace when the new edit extends the
            // old one in place (same range, old text a prefix of new text).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
360
361#[derive(Debug, Clone)]
362enum PredictionRequestedBy {
363 DiagnosticsUpdate,
364 Buffer(EntityId),
365}
366
367impl PredictionRequestedBy {
368 pub fn buffer_id(&self) -> Option<EntityId> {
369 match self {
370 PredictionRequestedBy::DiagnosticsUpdate => None,
371 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
372 }
373 }
374}
375
/// An in-flight prediction request; the task resolves to the prediction's
/// id, or `None` if no prediction was produced.
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; surface it as a jump.
    Jump { prediction: &'a EditPrediction },
}

/// Subscription state for a buffer tracked by a [`ZetaProject`].
struct RegisteredBuffer {
    /// Last snapshot whose changes were reported.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// A buffer change not yet finalized into an event, so that subsequent
/// nearby edits can still be coalesced into it.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// End of the most recent edit; used for proximity-based coalescing.
    end_edit_anchor: Option<Anchor>,
}
398
impl LastEvent {
    /// Converts this pending event into a `predict_edits_v3::Event`, or
    /// `None` when neither the path nor the contents changed.
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Open source only when both snapshots' worktrees have a watcher
        // that currently reports the project as open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        // No rename and no content change: nothing to report.
        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
435
436fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
437 if let Some(file) = snapshot.file() {
438 file.full_path(cx).into()
439 } else {
440 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
441 }
442}
443
444impl Zeta {
445 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
446 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
447 }
448
449 pub fn global(
450 client: &Arc<Client>,
451 user_store: &Entity<UserStore>,
452 cx: &mut App,
453 ) -> Entity<Self> {
454 cx.try_global::<ZetaGlobal>()
455 .map(|global| global.0.clone())
456 .unwrap_or_else(|| {
457 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
458 cx.set_global(ZetaGlobal(zeta.clone()));
459 zeta
460 })
461 }
462
    /// Creates a new [`Zeta`]. Prefer [`Zeta::global`] to obtain the shared
    /// instance instead of calling this directly.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        // Restore the user's persisted data-collection choice.
        let data_collection_choice = Self::load_data_collection_choice();

        // Long-lived worker: each message on `reject_rx` flushes the current
        // batch of rejected predictions (see `reject_edit_predictions`).
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
510
    /// Selects which backend serves predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }

    /// Installs the eval cache (eval-support builds only).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Creates a channel receiving debug events. A new call replaces the
    /// previous sender, so only the most recent receiver stays active.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Clears the recorded edit-history events for every project.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }
543
544 pub fn context_for_project(
545 &self,
546 project: &Entity<Project>,
547 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
548 self.projects
549 .get(&project.entity_id())
550 .and_then(|project| {
551 Some(
552 project
553 .context
554 .as_ref()?
555 .iter()
556 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
557 )
558 })
559 .into_iter()
560 .flatten()
561 }
562
563 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
564 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
565 self.user_store.read(cx).edit_prediction_usage()
566 } else {
567 None
568 }
569 }
570
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, subscribing to its edits
    /// and release.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
584
    /// Returns the [`ZetaProject`] for `project`, creating it (and
    /// subscribing to project events) on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // A syntax index is only needed for syntax-based context.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                cancelled_predictions: HashSet::default(),
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
617
    /// Reacts to project events: maintains the recent-paths list on active
    /// entry changes, and (behind the zeta2 flag) refreshes predictions when
    /// diagnostics update.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any older entry.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
650
    /// Registers `buffer` in `zeta_project` (idempotent), wiring up:
    /// - a license-detection watcher for the buffer's worktree,
    /// - an edit subscription that reports buffer changes,
    /// - a release observer that removes the registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        // Drop the watcher when its worktree is released.
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Report edits while the project is still alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Clean up the registration when the buffer drops.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
710
    /// Records an edit to `buffer`, coalescing it into the pending event when
    /// it continues the same edit run (same buffer, within
    /// `CHANGE_GROUPING_LINE_SPAN` rows of the previous edit).
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing changed since the snapshot we last recorded.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the most recent edit, for proximity grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // This edit continues the pending event only if it picks up
            // exactly where that event's snapshot left off...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and lands within CHANGE_GROUPING_LINE_SPAN rows of it.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the pending event in place.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Bound the history; the pending event counts toward the cap.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        // Finalize the previous pending event before starting a new one.
        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
773
    /// Returns the project's current prediction from `buffer`'s perspective:
    /// `Local` when the prediction edits this buffer, `Jump` when it targets
    /// another buffer but should be surfaced here as a jump target.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Buffer-requested predictions only offer a jump from the buffer
            // that requested them; diagnostics-triggered ones from anywhere.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
805
    /// Marks the current prediction accepted: reports acceptance to the
    /// server and cancels any still-pending requests (now stale).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Acceptance reporting only applies to the Zed-hosted models.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            self.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the accept endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
858
    /// Flushes the accumulated rejections to the server, then drops the
    /// entries that were included in the flush.
    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        // Rejection reporting only applies to the Zed-hosted models.
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the newest rejection so we know how much to drain after a
        // successful flush (more rejections may arrive while in flight).
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): on serialization failure `body` is `None`, which is
        // sent as-is below — confirm an empty body is intended in that case.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drain everything up to and including the rejection captured
            // before the request started.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
906
907 fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
908 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
909 project_state.pending_predictions.clear();
910 if let Some(prediction) = project_state.current_prediction.take() {
911 self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
912 }
913 };
914 }
915
916 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
917 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
918 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
919 if !current_prediction.was_shown {
920 current_prediction.was_shown = true;
921 self.shown_predictions
922 .push_front(current_prediction.prediction.clone());
923 if self.shown_predictions.len() > 50 {
924 let completion = self.shown_predictions.pop_back().unwrap();
925 self.rated_predictions.remove(&completion.id);
926 }
927 }
928 }
929 }
930 }
931
    /// Records a rejected prediction and schedules a flush to the server via
    /// `reject_predictions_tx`.
    fn discard_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            was_shown,
        });

        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
        let reject_tx = self.reject_predictions_tx.clone();
        // Replacing the previous task drops its pending timer, debouncing the
        // flush; when the batch is full, flush without waiting.
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
                    .await;
            }
            reject_tx.unbounded_send(()).log_err();
        }));
    }
956
957 fn cancel_pending_prediction(
958 &self,
959 pending_prediction: PendingPrediction,
960 cx: &mut Context<Self>,
961 ) {
962 cx.spawn(async move |this, cx| {
963 let Some(prediction_id) = pending_prediction.task.await else {
964 return;
965 };
966
967 this.update(cx, |this, cx| {
968 this.discard_prediction(prediction_id, false, cx);
969 })
970 .ok();
971 })
972 .detach()
973 }
974
975 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
976 self.projects
977 .get(&project.entity_id())
978 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
979 }
980
981 pub fn refresh_prediction_from_buffer(
982 &mut self,
983 project: Entity<Project>,
984 buffer: Entity<Buffer>,
985 position: language::Anchor,
986 cx: &mut Context<Self>,
987 ) {
988 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
989 let Some(request_task) = this
990 .update(cx, |this, cx| {
991 this.request_prediction(&project, &buffer, position, cx)
992 })
993 .log_err()
994 else {
995 return Task::ready(anyhow::Ok(None));
996 };
997
998 let project = project.clone();
999 cx.spawn(async move |cx| {
1000 if let Some(prediction) = request_task.await? {
1001 let id = prediction.id.clone();
1002 this.update(cx, |this, cx| {
1003 let project_state = this
1004 .projects
1005 .get_mut(&project.entity_id())
1006 .context("Project not found")?;
1007
1008 let new_prediction = CurrentEditPrediction {
1009 requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
1010 prediction: prediction,
1011 was_shown: false,
1012 };
1013
1014 if project_state
1015 .current_prediction
1016 .as_ref()
1017 .is_none_or(|old_prediction| {
1018 new_prediction.should_replace_prediction(&old_prediction, cx)
1019 })
1020 {
1021 project_state.current_prediction = Some(new_prediction);
1022 cx.notify();
1023 }
1024 anyhow::Ok(())
1025 })??;
1026 Ok(Some(id))
1027 } else {
1028 Ok(None)
1029 }
1030 })
1031 })
1032 }
1033
    /// Requests a prediction at the next diagnostic location in the active
    /// buffer. Buffer-triggered predictions take precedence: this does
    /// nothing when a current prediction already exists, and the result is
    /// only installed if no other prediction arrived in the meantime.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Open the project's active entry, if there is one.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Locate the next diagnostic to jump to (possibly in a
                // different buffer).
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                let id = prediction.id.clone();
                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Install only if no prediction appeared meanwhile.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                                was_shown: false,
                            }
                        });
                    }
                })?;

                anyhow::Ok(Some(id))
            })
        });
    }
1106
    /// Minimum delay between consecutive prediction refreshes for the same
    /// throttle entity (see `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero in tests so they run without artificial delays.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1111
    /// Schedules `do_refresh` to run as a pending prediction, throttling
    /// back-to-back refreshes for the same `throttle_entity` and keeping at
    /// most two pending predictions per project (one in flight plus a single
    /// queued successor; a newer request cancels and replaces the successor).
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        ) -> Task<Result<Option<EditPredictionId>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: if the previous refresh for this same entity was
            // recent, wait out the remainder of THROTTLE_TIMEOUT first.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // If this task was cancelled before the throttle timeout expired,
            // do not perform a request.
            let mut is_cancelled = true;
            this.update(cx, |this, cx| {
                let project_state = this.get_or_init_zeta_project(&project, cx);
                if !project_state
                    .cancelled_predictions
                    .remove(&pending_prediction_id)
                {
                    project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
                    is_cancelled = false;
                }
            })
            .ok();
            if is_cancelled {
                return None;
            }

            let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);
                zeta_project
                    .cancelled_predictions
                    .remove(&pending_prediction_id);

                // Take the list to avoid aliasing `this` while cancelling.
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            this.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            edit_prediction_id
        });

        // Cap pending predictions at two: the in-flight one and one queued
        // successor. A newer request replaces (and cancels) the successor.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            zeta_project
                .cancelled_predictions
                .insert(pending_prediction.id);
            self.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1201
1202 pub fn request_prediction(
1203 &mut self,
1204 project: &Entity<Project>,
1205 active_buffer: &Entity<Buffer>,
1206 position: language::Anchor,
1207 cx: &mut Context<Self>,
1208 ) -> Task<Result<Option<EditPrediction>>> {
1209 self.request_prediction_internal(
1210 project.clone(),
1211 active_buffer.clone(),
1212 position,
1213 cx.has_flag::<Zeta2FeatureFlag>(),
1214 cx,
1215 )
1216 }
1217
    /// Dispatches a prediction request to the configured model (Zeta1, Zeta2,
    /// or Sweep). If the model returns no edits and `allow_jump` is set,
    /// retries once at the nearest diagnostic outside the searched window.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // Rows above/below the cursor considered "nearby" for diagnostics.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            // Treat an empty edit list the same as no prediction.
            let prediction = task
                .await?
                .filter(|prediction| !prediction.edits.is_empty());

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there are recent edit events, and recurse
                // with allow_jump=false so we jump at most once.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1307
    /// Finds a diagnostic to jump to: first the diagnostic in the active
    /// buffer closest to the cursor but outside the already-searched range,
    /// otherwise the most severe diagnostic in the other buffer (by path
    /// proximity) that has diagnostics. Returns `None` when neither exists.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // No candidate in the active buffer: look for another buffer in
            // the project that has diagnostics.
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Pick the highest-severity entry (lower severity values rank
                // first — TODO confirm ordering matches lsp::DiagnosticSeverity).
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1386
    /// Builds and sends a Zeta2 edit-prediction request for the cursor
    /// position, then parses the model's response into buffer edits.
    ///
    /// Gathers context (either agentic search results or syntax-based
    /// context), nearby diagnostics, and recent events; builds the prompt;
    /// sends it to the LLM endpoint; and parses the returned text according
    /// to the configured prompt format.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the buffer's parent directory, when it has one.
        let file = active_buffer.read(cx).file();
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        // Under eval, wait for context buffers to finish parsing so results
        // are deterministic.
        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Snapshot each context buffer along with its path and ranges.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Stable ordering so prompts are reproducible across runs.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // entry (moving it last), or append a new entry.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line-based excerpts for
                        // the wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // Wire up a one-shot channel so the debug view can observe
                // the raw response (or its error) and the request latency.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch to exercise the pipeline without
                // hitting the network.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path mentioned in the model output to the
                // corresponding context buffer snapshot and ranges.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output into edits, format depending on the
                // prompt format the request was built with.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        inputs,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((
                    id,
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::new(
                    id,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    received_response_at,
                    inputs,
                    cx,
                )
                .await)
            }
        })
    }
1790
    /// Sends an OpenAI-style request to the edit-prediction endpoint
    /// (`PREDICT_EDITS_URL` override or the Zed LLM raw endpoint), returning
    /// the response and any usage info. Under the eval-support feature,
    /// responses are cached keyed by a hash of the URL and request body.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Eval cache lookup: return the cached response on a hit, otherwise
        // remember the key so the fresh response can be written below.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1848
    /// Post-processes an API response: records edit-prediction usage in the
    /// user store on success, and on `ZedUpdateRequiredError` marks the
    /// entity as requiring an update and shows a notification with an
    /// update link. The error (if any) is passed through to the caller.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1892
    /// Sends an authenticated POST request built by `build` and deserializes
    /// the JSON response.
    ///
    /// Enforces the server's minimum-required-version header (failing with
    /// `ZedUpdateRequiredError` when the app is too old) and retries exactly
    /// once with a refreshed token when the LLM token has expired.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version; surface a
            // typed error so callers can prompt the user to update.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and loop around to retry.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1956
    /// How long without a context refresh before the user counts as idle;
    /// editing after an idle period triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied after an edit (when not idle) before refreshing context.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1959
    /// Refreshes the related excerpts when the user has just begun editing
    /// after an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Idle when no refresh was requested within the idle window.
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the debounce task drops (cancels) any previous one.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2007
    /// Refreshes the related excerpts asynchronously. Ensures the task runs to
    /// completion, and avoids spawning more than one concurrent task.
    ///
    /// Asks the retrieval model to plan searches (via a tool call), runs the
    /// resulting search queries against the project, and stores the matched
    /// excerpts as the project's prediction context.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Only agentic context mode performs retrieval.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt and
        // recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Tool schema and its description, derived once from SearchToolInput.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect search queries from the tool calls, skipping (but
            // logging) any call to a tool we don't recognize.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the retrieved excerpts (or propagate the search error)
            // and clear the in-flight task marker.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2219
2220 pub fn set_context(
2221 &mut self,
2222 project: Entity<Project>,
2223 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2224 ) {
2225 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2226 zeta_project.context = Some(context);
2227 }
2228 }
2229
    /// Collects diagnostic groups from every language server, ordered by
    /// proximity to `cursor_offset`, each serialized as a raw JSON value.
    ///
    /// Groups are included until their cumulative serialized size would exceed
    /// `max_diagnostics_bytes`; the returned bool reports whether any groups
    /// were dropped by that budget.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            // Distance is measured from the cursor to the group's primary
            // entry range.
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                // Group starts at or after the cursor.
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                // Group ends at or before the cursor.
                cursor_offset - range.end
            } else {
                // Cursor is inside the range: distance to the nearest edge.
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                // Budget exhausted; nearer groups were already included.
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }
2275
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI by gathering
    /// syntax-based context around `position` in `buffer` on a background
    /// task.
    ///
    /// Returns an error if the buffer has no file path. Panics when the
    /// configured context mode is `Agentic`, which the CLI does not support
    /// yet (NOTE(review): the panic message says "Llm mode" while the variant
    /// is `Agentic` — presumably a leftover from a rename; confirm).
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone of the project's syntax index state, if it has one.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the buffer
        // belongs to a worktree file.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2353
2354 pub fn wait_for_initial_indexing(
2355 &mut self,
2356 project: &Entity<Project>,
2357 cx: &mut Context<Self>,
2358 ) -> Task<Result<()>> {
2359 let zeta_project = self.get_or_init_zeta_project(project, cx);
2360 if let Some(syntax_index) = &zeta_project.syntax_index {
2361 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2362 } else {
2363 Task::ready(Ok(()))
2364 }
2365 }
2366
2367 fn is_file_open_source(
2368 &self,
2369 project: &Entity<Project>,
2370 file: &Arc<dyn File>,
2371 cx: &App,
2372 ) -> bool {
2373 if !file.is_local() || file.is_private() {
2374 return false;
2375 }
2376 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2377 return false;
2378 };
2379 zeta_project
2380 .license_detection_watchers
2381 .get(&file.worktree_id(cx))
2382 .as_ref()
2383 .is_some_and(|watcher| watcher.is_project_open_source())
2384 }
2385
2386 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2387 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2388 }
2389
2390 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2391 if !self.data_collection_choice.is_enabled() {
2392 return false;
2393 }
2394 events.iter().all(|event| {
2395 matches!(
2396 event.as_ref(),
2397 Event::BufferChange {
2398 in_open_source_repo: true,
2399 ..
2400 }
2401 )
2402 })
2403 }
2404
2405 fn load_data_collection_choice() -> DataCollectionChoice {
2406 let choice = KEY_VALUE_STORE
2407 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2408 .log_err()
2409 .flatten();
2410
2411 match choice.as_deref() {
2412 Some("true") => DataCollectionChoice::Enabled,
2413 Some("false") => DataCollectionChoice::Disabled,
2414 Some(_) => {
2415 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2416 DataCollectionChoice::NotAnswered
2417 }
2418 None => DataCollectionChoice::NotAnswered,
2419 }
2420 }
2421
    /// Iterates over the predictions that have been shown to the user, in
    /// insertion order.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2425
    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2429
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2433
    /// Records a user rating for `prediction`: marks it rated, emits a
    /// telemetry event carrying the prediction's inputs and a unified diff of
    /// its edits, and flushes telemetry immediately.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away so the rating isn't lost if the app quits soon.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2452}
2453
2454pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2455 let choice = res.choices.pop()?;
2456 let output_text = match choice.message {
2457 open_ai::RequestMessage::Assistant {
2458 content: Some(open_ai::MessageContent::Plain(content)),
2459 ..
2460 } => content,
2461 open_ai::RequestMessage::Assistant {
2462 content: Some(open_ai::MessageContent::Multipart(mut content)),
2463 ..
2464 } => {
2465 if content.is_empty() {
2466 log::error!("No output from Baseten completion response");
2467 return None;
2468 }
2469
2470 match content.remove(0) {
2471 open_ai::MessagePart::Text { text } => text,
2472 open_ai::MessagePart::Image { .. } => {
2473 log::error!("Expected text, got an image");
2474 return None;
2475 }
2476 }
2477 }
2478 _ => {
2479 log::error!("Invalid response message: {:?}", choice.message);
2480 return None;
2481 }
2482 };
2483 Some(output_text)
2484}
2485
/// Error raised when the server reports this client is too old for edit
/// predictions; carries the minimum supported Zed version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2493
2494fn make_syntax_context_cloud_request(
2495 excerpt_path: Arc<Path>,
2496 context: EditPredictionContext,
2497 events: Vec<Arc<predict_edits_v3::Event>>,
2498 can_collect_data: bool,
2499 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2500 diagnostic_groups_truncated: bool,
2501 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2502 debug_info: bool,
2503 worktrees: &Vec<worktree::Snapshot>,
2504 index_state: Option<&SyntaxIndexState>,
2505 prompt_max_bytes: Option<usize>,
2506 prompt_format: PromptFormat,
2507) -> predict_edits_v3::PredictEditsRequest {
2508 let mut signatures = Vec::new();
2509 let mut declaration_to_signature_index = HashMap::default();
2510 let mut referenced_declarations = Vec::new();
2511
2512 for snippet in context.declarations {
2513 let project_entry_id = snippet.declaration.project_entry_id();
2514 let Some(path) = worktrees.iter().find_map(|worktree| {
2515 worktree.entry_for_id(project_entry_id).map(|entry| {
2516 let mut full_path = RelPathBuf::new();
2517 full_path.push(worktree.root_name());
2518 full_path.push(&entry.path);
2519 full_path
2520 })
2521 }) else {
2522 continue;
2523 };
2524
2525 let parent_index = index_state.and_then(|index_state| {
2526 snippet.declaration.parent().and_then(|parent| {
2527 add_signature(
2528 parent,
2529 &mut declaration_to_signature_index,
2530 &mut signatures,
2531 index_state,
2532 )
2533 })
2534 });
2535
2536 let (text, text_is_truncated) = snippet.declaration.item_text();
2537 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2538 path: path.as_std_path().into(),
2539 text: text.into(),
2540 range: snippet.declaration.item_line_range(),
2541 text_is_truncated,
2542 signature_range: snippet.declaration.signature_range_in_item_text(),
2543 parent_index,
2544 signature_score: snippet.score(DeclarationStyle::Signature),
2545 declaration_score: snippet.score(DeclarationStyle::Declaration),
2546 score_components: snippet.components,
2547 });
2548 }
2549
2550 let excerpt_parent = index_state.and_then(|index_state| {
2551 context
2552 .excerpt
2553 .parent_declarations
2554 .last()
2555 .and_then(|(parent, _)| {
2556 add_signature(
2557 *parent,
2558 &mut declaration_to_signature_index,
2559 &mut signatures,
2560 index_state,
2561 )
2562 })
2563 });
2564
2565 predict_edits_v3::PredictEditsRequest {
2566 excerpt_path,
2567 excerpt: context.excerpt_text.body,
2568 excerpt_line_range: context.excerpt.line_range,
2569 excerpt_range: context.excerpt.range,
2570 cursor_point: predict_edits_v3::Point {
2571 line: predict_edits_v3::Line(context.cursor_point.row),
2572 column: context.cursor_point.column,
2573 },
2574 referenced_declarations,
2575 included_files: vec![],
2576 signatures,
2577 excerpt_parent,
2578 events,
2579 can_collect_data,
2580 diagnostic_groups,
2581 diagnostic_groups_truncated,
2582 git_info,
2583 debug_info,
2584 prompt_max_bytes,
2585 prompt_format,
2586 }
2587}
2588
/// Ensures the signature for `declaration_id` (and, recursively, all of its
/// ancestors) is present in `signatures`, returning its index.
///
/// `declaration_to_signature_index` memoizes prior insertions so each
/// declaration is serialized at most once. Returns `None` when the
/// declaration is missing from the index.
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    // NOTE(review): despite the name, this is the declaration for
    // `declaration_id` itself — call sites pass parent ids, hence the name.
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    // Register ancestors first so `parent_index` is valid before we push.
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}
2616
/// Cache key for eval runs: the kind of LLM request plus a 64-bit hash of its
/// input.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2619
/// The kind of LLM request an eval cache entry corresponds to.
#[cfg(feature = "eval-support")]
// `Eq` and `Hash` added so `EvalCacheKey` tuples can be used directly as
// hashed-map keys by `EvalCache` implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    /// Context-gathering (search planning) request.
    Context,
    /// Retrieval search execution.
    Search,
    /// Edit prediction request.
    Prediction,
}
2627
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name used in eval cache file names/logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2638
/// Store for caching LLM request/response pairs across eval runs.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`; `input` is the request text the key was
    /// derived from (presumably kept for inspection — confirm with impls).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2644
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// The user has not answered the prompt yet.
    NotAnswered,
    /// The user opted in to data collection.
    Enabled,
    /// The user opted out of data collection.
    Disabled,
}

impl DataCollectionChoice {
    /// Whether data collection is active; an unanswered prompt counts as
    /// disabled.
    pub fn is_enabled(self) -> bool {
        match self {
            Self::Enabled => true,
            Self::NotAnswered | Self::Disabled => false,
        }
    }

    /// Whether the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        match self {
            Self::Enabled | Self::Disabled => true,
            Self::NotAnswered => false,
        }
    }

    /// Returns the opposite choice; toggling an unanswered choice enables
    /// collection.
    ///
    /// Takes `self` by value for consistency with the other methods on this
    /// `Copy` type (previously `&self`; call sites work unchanged via
    /// auto-deref).
    #[must_use]
    pub fn toggle(self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled => Self::Enabled,
            Self::NotAnswered => Self::Enabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        match value {
            true => DataCollectionChoice::Enabled,
            false => DataCollectionChoice::Disabled,
        }
    }
}
2685
2686struct ZedPredictUpsell;
2687
2688impl Dismissable for ZedPredictUpsell {
2689 const KEY: &'static str = "dismissed-edit-predict-upsell";
2690
2691 fn dismissed() -> bool {
2692 // To make this backwards compatible with older versions of Zed, we
2693 // check if the user has seen the previous Edit Prediction Onboarding
2694 // before, by checking the data collection choice which was written to
2695 // the database once the user clicked on "Accept and Enable"
2696 if KEY_VALUE_STORE
2697 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2698 .log_err()
2699 .is_some_and(|s| s.is_some())
2700 {
2701 return true;
2702 }
2703
2704 KEY_VALUE_STORE
2705 .read_kvp(Self::KEY)
2706 .log_err()
2707 .is_some_and(|s| s.is_some())
2708 }
2709}
2710
/// Whether the edit-prediction upsell modal should be shown, i.e. the user
/// has neither dismissed it nor completed the legacy onboarding.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2714
/// Registers zeta's workspace actions and sets up command-palette gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only available behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // NOTE(review): "reset onboarding" sets the provider to `None`,
        // presumably so the onboarding flow re-triggers — confirm intent.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2750
/// Hides/shows zeta's command-palette actions based on the "disable AI"
/// setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    // Every zeta action, hidden wholesale when AI is disabled.
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Start with everything hidden; the observers below re-show as allowed.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        // NOTE(review): when AI is re-enabled, only the rate-completion
        // actions are shown again here; the rest of `zeta_all_action_types`
        // stays hidden — confirm they are re-shown elsewhere.
        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2802
2803#[cfg(test)]
2804mod tests {
2805 use std::{path::Path, sync::Arc};
2806
2807 use client::UserStore;
2808 use clock::FakeSystemClock;
2809 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2810 use futures::{
2811 AsyncReadExt, StreamExt,
2812 channel::{mpsc, oneshot},
2813 };
2814 use gpui::{
2815 Entity, TestAppContext,
2816 http_client::{FakeHttpClient, Response},
2817 prelude::*,
2818 };
2819 use indoc::indoc;
2820 use language::OffsetRangeExt as _;
2821 use open_ai::Usage;
2822 use pretty_assertions::{assert_eq, assert_matches};
2823 use project::{FakeFs, Project};
2824 use serde_json::json;
2825 use settings::SettingsStore;
2826 use util::path;
2827 use uuid::Uuid;
2828
2829 use crate::{BufferEditPrediction, Zeta};
2830
    /// End-to-end flow: local prediction, context refresh via the search
    /// tool, then a prediction that targets a different file ("jump").
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        // Edit targets the buffer the cursor is in, so it's `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Respond with a tool call that searches for the other file.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.discard_current_prediction(&project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction points elsewhere: `Jump`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // From the target buffer's perspective the same prediction is `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2977
    /// A single prediction request round-trip: the model's diff is converted
    /// into a buffer edit at the expected position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff replaces "How" with "How are you?", so the resulting edit
        // inserts " are you?" at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3041
    /// Verifies that buffer edits made before a request show up as an event
    /// diff in the prompt sent to the model.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register first so the subsequent edit is tracked as an event.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The tracked edit should appear in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3115
3116 // Skipped until we start including diagnostics in prompt
3117 // #[gpui::test]
3118 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3119 // let (zeta, mut req_rx) = init_test(cx);
3120 // let fs = FakeFs::new(cx.executor());
3121 // fs.insert_tree(
3122 // "/root",
3123 // json!({
3124 // "foo.md": "Hello!\nBye"
3125 // }),
3126 // )
3127 // .await;
3128 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3129
3130 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3131 // let diagnostic = lsp::Diagnostic {
3132 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3133 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3134 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3135 // ..Default::default()
3136 // };
3137
3138 // project.update(cx, |project, cx| {
3139 // project.lsp_store().update(cx, |lsp_store, cx| {
3140 // // Create some diagnostics
3141 // lsp_store
3142 // .update_diagnostics(
3143 // LanguageServerId(0),
3144 // lsp::PublishDiagnosticsParams {
3145 // uri: path_to_buffer_uri.clone(),
3146 // diagnostics: vec![diagnostic],
3147 // version: None,
3148 // },
3149 // None,
3150 // language::DiagnosticSourceKind::Pushed,
3151 // &[],
3152 // cx,
3153 // )
3154 // .unwrap();
3155 // });
3156 // });
3157
3158 // let buffer = project
3159 // .update(cx, |project, cx| {
3160 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3161 // project.open_buffer(path, cx)
3162 // })
3163 // .await
3164 // .unwrap();
3165
3166 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3167 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3168
3169 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3170 // zeta.request_prediction(&project, &buffer, position, cx)
3171 // });
3172
3173 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3174
3175 // assert_eq!(request.diagnostic_groups.len(), 1);
3176 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3177 // .unwrap();
3178 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3179 // assert_eq!(
3180 // value,
3181 // json!({
3182 // "entries": [{
3183 // "range": {
3184 // "start": 8,
3185 // "end": 10
3186 // },
3187 // "diagnostic": {
3188 // "source": null,
3189 // "code": null,
3190 // "code_description": null,
3191 // "severity": 1,
3192 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3193 // "markdown": null,
3194 // "group_id": 0,
3195 // "is_primary": true,
3196 // "is_disk_based": false,
3197 // "is_unnecessary": false,
3198 // "source_kind": "Pushed",
3199 // "data": null,
3200 // "underline": true
3201 // }
3202 // }],
3203 // "primary_ix": 0
3204 // })
3205 // );
3206 // }
3207
3208 fn model_response(text: &str) -> open_ai::Response {
3209 open_ai::Response {
3210 id: Uuid::new_v4().to_string(),
3211 object: "response".into(),
3212 created: 0,
3213 model: "model".into(),
3214 choices: vec![open_ai::Choice {
3215 index: 0,
3216 message: open_ai::RequestMessage::Assistant {
3217 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3218 tool_calls: vec![],
3219 },
3220 finish_reason: None,
3221 }],
3222 usage: Usage {
3223 prompt_tokens: 0,
3224 completion_tokens: 0,
3225 total_tokens: 0,
3226 },
3227 }
3228 }
3229
3230 fn prompt_from_request(request: &open_ai::Request) -> &str {
3231 assert_eq!(request.messages.len(), 1);
3232 let open_ai::RequestMessage::User {
3233 content: open_ai::MessageContent::Plain(content),
3234 ..
3235 } = &request.messages[0]
3236 else {
3237 panic!(
3238 "Request does not have single user message of type Plain. {:#?}",
3239 request
3240 );
3241 };
3242 content
3243 }
3244
    /// Test fixture: builds a `Zeta` wired to a fake HTTP client.
    ///
    /// Returns the entity plus a channel yielding each intercepted
    /// `/predict_edits/raw` request together with a oneshot sender the test
    /// uses to supply the model's response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh endpoint: always succeeds.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: forward the parsed request
                            // to the test and wait for its canned response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
3299}