1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
7 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
8 RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
9};
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
12use collections::{HashMap, HashSet};
13use command_palette_hooks::CommandPaletteFilter;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{
16 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
17 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
18 SyntaxIndex, SyntaxIndexState,
19};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
21use futures::channel::{mpsc, oneshot};
22use futures::{AsyncReadExt as _, StreamExt as _};
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::{
29 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
30};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, RefreshLlmTokenListener};
33use open_ai::FunctionDefinition;
34use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
35use release_channel::AppVersion;
36use semver::Version;
37use serde::de::DeserializeOwned;
38use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
39use std::any::{Any as _, TypeId};
40use std::collections::{VecDeque, hash_map};
41use telemetry_events::EditPredictionRating;
42use workspace::Workspace;
43
44use std::ops::Range;
45use std::path::Path;
46use std::rc::Rc;
47use std::str::FromStr as _;
48use std::sync::{Arc, LazyLock};
49use std::time::{Duration, Instant};
50use std::{env, mem};
51use thiserror::Error;
52use util::rel_path::RelPathBuf;
53use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
54use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
55
56pub mod assemble_excerpts;
57mod license_detection;
58mod onboarding_modal;
59mod prediction;
60mod provider;
61mod rate_prediction_modal;
62pub mod retrieval_search;
63mod sweep_ai;
64pub mod udiff;
65mod xml_edits;
66pub mod zeta1;
67
68#[cfg(test)]
69mod zeta_tests;
70
71use crate::assemble_excerpts::assemble_excerpts;
72use crate::license_detection::LicenseDetectionWatcher;
73use crate::onboarding_modal::ZedPredictModal;
74pub use crate::prediction::EditPrediction;
75pub use crate::prediction::EditPredictionId;
76pub use crate::prediction::EditPredictionInputs;
77use crate::rate_prediction_modal::{
78 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
79 ThumbsUpActivePrediction,
80};
81use crate::sweep_ai::SweepAi;
82use crate::zeta1::request_prediction_with_zeta1;
83pub use provider::ZetaEditPredictionProvider;
84
// Declares the `edit_prediction` action namespace; each doc comment below is
// attached to the generated action type (presumably surfaced as the action's
// description in the UI — confirm against the `actions!` macro docs).
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
96
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Edits landing within this many rows of the previous edit are coalesced
/// into the same pending change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Persistence key for the user's data-collection choice (loaded in
/// `Zeta::new` via `load_data_collection_choice`).
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
101
102pub struct SweepFeatureFlag;
103
104impl FeatureFlag for SweepFeatureFlag {
105 const NAME: &str = "sweep-ai";
106}
/// Default sizing of the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Context gathering defaults to the agentic mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context retrieval ([`ContextMode::Syntax`]).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Options installed when a `Zeta` is constructed (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
138
/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, model ids and the
/// prediction URL below default to a local Ollama server.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for the context-retrieval step; override with
/// `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit predictions; override with `ZED_ZETA2_MODEL`
/// (the special value `zeta2-exp` selects the fine-tuned model).
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Prediction endpoint override: `ZED_PREDICT_EDITS_URL` wins; otherwise the
/// local Ollama endpoint when `USE_OLLAMA` is set; otherwise `None`.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
166
/// Feature flag gating the Zeta2 prediction path; enabled for staff.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        true
    }
}
176
/// Global handle to the singleton [`Zeta`] entity (installed by [`Zeta::global`]).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
181
/// Singleton entity coordinating edit predictions: tracks per-project state,
/// requests predictions from the configured backend, and reports accepted and
/// rejected predictions back to the server.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription alive for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server demands a newer client
    // (cf. MINIMUM_REQUIRED_VERSION_HEADER_NAME import) — setter not in view.
    update_required: bool,
    /// Sender half of the channel handed out by `debug_info()`.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    sweep_ai: SweepAi,
    data_collection_choice: DataCollectionChoice,
    /// Rejections buffered until flushed by `reject_edit_predictions`.
    rejected_predictions: Vec<EditPredictionRejection>,
    /// Each message triggers one flush of `rejected_predictions` (see `new`).
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    /// Predictions shown to the user, newest first, capped at 50.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
202
/// Which backend serves edit predictions.
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
210
/// Tunable parameters for prediction requests (see [`DEFAULT_OPTIONS`]).
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered.
    pub context: ContextMode,
    /// Upper bound on prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostics included in a request, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // NOTE(review): consumer of this interval is not in this chunk — verify
    // against the rest of the file before documenting its exact semantics.
    pub buffer_change_grouping_interval: Duration,
}
220
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval using search tool queries.
    Agentic(AgenticContextOptions),
    /// Retrieval backed by a [`SyntaxIndex`] built over the project.
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
231
232impl ContextMode {
233 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
234 match self {
235 ContextMode::Agentic(options) => &options.excerpt,
236 ContextMode::Syntax(options) => &options.excerpt,
237 }
238 }
239}
240
/// Debug events emitted over the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
249
/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Shared payload for context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally rendered prompt, or an error message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error message) and elapsed time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
281
/// Per-project prediction state, stored in `Zeta::projects`.
struct ZetaProject {
    /// Only populated in [`ContextMode::Syntax`] (see `get_or_init_zeta_project`).
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized buffer-change events, oldest first; bounded by `EVENT_COUNT_MAX`.
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The newest, still-coalescing change; finalized lazily.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two prediction requests in flight at once.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Entity and time of the last refresh, used for throttling
    /// (see `queue_prediction_refresh`).
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Retrieved context: for each buffer, the ranges gathered as context.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    /// One license watcher per worktree, created in `register_buffer_impl`.
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
299
impl ZetaProject {
    /// All tracked events: the finalized queue plus the still-pending
    /// `last_event`, finalized on the fly (and skipped if it produced no
    /// observable change).
    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
        self.events
            .iter()
            .cloned()
            .chain(
                self.last_event
                    .as_ref()
                    .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
            )
            .collect()
    }
}
313
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request that produced this prediction.
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been displayed to the user yet.
    pub was_shown: bool,
}
320
impl CurrentEditPrediction {
    /// Decides whether `self` (a newly-arrived prediction) should replace
    /// `old_prediction`. Returns `false` if the new prediction no longer
    /// interpolates against its buffer's current contents; otherwise prefers
    /// replacement, except when the new edit merely extends the same single
    /// edit in place — keeping the old one then avoids UI churn.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // New prediction is already stale relative to the buffer: keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // Old prediction no longer interpolates: it is stale, so replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
359
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    /// A diagnostics update on the project.
    DiagnosticsUpdate,
    /// An edit in the buffer with this entity id.
    Buffer(EntityId),
}
365
366impl PredictionRequestedBy {
367 pub fn buffer_id(&self) -> Option<EntityId> {
368 match self {
369 PredictionRequestedBy::DiagnosticsUpdate => None,
370 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
371 }
372 }
373}
374
/// An in-flight prediction request; `task` resolves to the resulting
/// prediction's id, if any.
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

/// Bookkeeping for a buffer whose edits are tracked.
struct RegisteredBuffer {
    /// Snapshot as of the last processed edit.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// The newest, still-coalescing buffer change (see `report_changes_for_buffer`).
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// End of the most recent edit, used to group nearby follow-up edits.
    end_edit_anchor: Option<Anchor>,
}
397
impl LastEvent {
    /// Converts this pending change into a `BufferChange` event. Returns
    /// `None` when nothing observable changed (same path and empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // True only when both old and new files belong to worktrees whose
        // license watcher classifies the project as open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
434
435fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
436 if let Some(file) = snapshot.file() {
437 file.full_path(cx).into()
438 } else {
439 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
440 }
441}
442
443impl Zeta {
444 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
445 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
446 }
447
    /// Returns the global [`Zeta`] instance, creating and registering it on
    /// first use.
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }
461
    /// Creates the `Zeta` entity: loads the persisted data-collection choice,
    /// spawns the background loop that flushes rejected predictions, and
    /// subscribes to LLM token refresh events.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        // Each message on this channel triggers one flush of
        // `rejected_predictions` (see `reject_edit_predictions`).
        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    // Re-fetch the LLM API token whenever the listener fires.
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            sweep_ai: SweepAi::new(cx),
            data_collection_choice,
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
509
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a Sweep AI API token is configured.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_ai.api_token.is_some()
    }

    /// Installs an eval cache (eval-support builds only).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Starts emitting [`ZetaDebugInfo`] events and returns the receiver.
    /// Any previously returned receiver stops getting events.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// The currently configured prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Clears recorded buffer-change events for every project.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }
542
543 pub fn context_for_project(
544 &self,
545 project: &Entity<Project>,
546 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
547 self.projects
548 .get(&project.entity_id())
549 .and_then(|project| {
550 Some(
551 project
552 .context
553 .as_ref()?
554 .iter()
555 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
556 )
557 })
558 .into_iter()
559 .flatten()
560 }
561
    /// Edit prediction usage stats; only reported for the Zeta2 backend.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
            self.user_store.read(cx).edit_prediction_usage()
        } else {
            None
        }
    }

    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking edits to `buffer` within `project`.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
583
    /// Returns the [`ZetaProject`] for `project`, creating it on first use.
    /// In [`ContextMode::Syntax`], creation also builds the project's
    /// [`SyntaxIndex`].
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
615
    /// Routes project events: maintains the recent-paths list when the active
    /// entry changes, and (behind the Zeta2 feature flag) refreshes predictions
    /// on diagnostics updates.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, removing any older occurrence.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                if cx.has_flag::<Zeta2FeatureFlag>() {
                    self.refresh_prediction_from_diagnostics(project, cx);
                }
            }
            _ => (),
        }
    }
648
    /// Registers `buffer` in `zeta_project` (idempotent): ensures a license
    /// watcher exists for the buffer's worktree, and subscribes to buffer
    /// edits and release so changes are tracked until the buffer is dropped.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create a license watcher per worktree; it is removed again
        // when the worktree entity is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Track edits while the buffer lives...
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // ...and unregister it when it is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
708
    /// Records the latest edits to `buffer` as change events. A follow-up edit
    /// to the same buffer within `CHANGE_GROUPING_LINE_SPAN` rows of the
    /// previous edit is coalesced into the pending `last_event`; otherwise the
    /// pending event is finalized and a new one is started.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        // Nothing changed since the snapshot we already recorded.
        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // End position of the last edit in this batch, for proximity grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only if this edit directly continues the pending
            // event's snapshot chain on the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and lands within CHANGE_GROUPING_LINE_SPAN rows of it.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Make room before finalizing the pending event into the queue.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
771
    /// How the project's current prediction (if any) relates to `buffer`:
    /// `Local` when it edits this buffer, `Jump` when it targets a different
    /// buffer but this buffer requested it (or it came from a diagnostics
    /// update), and `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
803
    /// Accepts the current prediction for `project`: cancels any pending
    /// requests and notifies the server via `/predict_edits/accept`.
    /// No-op for the Sweep backend.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting supersedes any in-flight requests; their results (if any)
        // are recorded as discarded.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            self.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (e.g. for testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
856
    /// Flushes buffered rejections to the server via `/predict_edits/reject`.
    /// On success, drops every rejection up to and including the last one that
    /// was part of this flush; entries added while the request was in flight
    /// are kept for the next flush. No-op for the Sweep backend or when the
    /// buffer is empty.
    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the last buffered rejection so we know how much to drain
        // after the request succeeds.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            this.update(cx, |this, _| {
                // Drain only what this flush sent; later entries stay queued.
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
904
    /// Drops the current prediction and all pending requests for `project`,
    /// recording the dropped prediction as rejected.
    fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.pending_predictions.clear();
            if let Some(prediction) = project_state.current_prediction.take() {
                self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
            }
        };
    }
913
914 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
915 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
916 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
917 if !current_prediction.was_shown {
918 current_prediction.was_shown = true;
919 self.shown_predictions
920 .push_front(current_prediction.prediction.clone());
921 if self.shown_predictions.len() > 50 {
922 let completion = self.shown_predictions.pop_back().unwrap();
923 self.rated_predictions.remove(&completion.id);
924 }
925 }
926 }
927 }
928 }
929
    /// Buffers a rejection for `prediction_id` and schedules a flush: after a
    /// 15-second debounce normally, or immediately once the per-request
    /// rejection limit has been reached.
    fn discard_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            was_shown,
        });

        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
        let reject_tx = self.reject_predictions_tx.clone();
        // Replacing the task cancels the previous debounce timer.
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
                    .await;
            }
            reject_tx.unbounded_send(()).log_err();
        }));
    }
954
    /// Lets a cancelled pending request run to completion and, if it produced
    /// a prediction, records that prediction as discarded (never shown).
    fn cancel_pending_prediction(
        &self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Self>,
    ) {
        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.discard_prediction(prediction_id, false, cx);
            })
            .ok();
        })
        .detach()
    }
972
973 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
974 self.projects
975 .get(&project.entity_id())
976 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
977 }
978
979 pub fn refresh_prediction_from_buffer(
980 &mut self,
981 project: Entity<Project>,
982 buffer: Entity<Buffer>,
983 position: language::Anchor,
984 cx: &mut Context<Self>,
985 ) {
986 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
987 let Some(request_task) = this
988 .update(cx, |this, cx| {
989 this.request_prediction(&project, &buffer, position, cx)
990 })
991 .log_err()
992 else {
993 return Task::ready(anyhow::Ok(None));
994 };
995
996 let project = project.clone();
997 cx.spawn(async move |cx| {
998 if let Some(prediction) = request_task.await? {
999 let id = prediction.id.clone();
1000 this.update(cx, |this, cx| {
1001 let project_state = this
1002 .projects
1003 .get_mut(&project.entity_id())
1004 .context("Project not found")?;
1005
1006 let new_prediction = CurrentEditPrediction {
1007 requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
1008 prediction: prediction,
1009 was_shown: false,
1010 };
1011
1012 if project_state
1013 .current_prediction
1014 .as_ref()
1015 .is_none_or(|old_prediction| {
1016 new_prediction.should_replace_prediction(&old_prediction, cx)
1017 })
1018 {
1019 project_state.current_prediction = Some(new_prediction);
1020 cx.notify();
1021 }
1022 anyhow::Ok(())
1023 })??;
1024 Ok(Some(id))
1025 } else {
1026 Ok(None)
1027 }
1028 })
1029 })
1030 }
1031
    /// Requests a prediction at the next diagnostic location in the project's
    /// active buffer. Skipped entirely when a current prediction already
    /// exists (buffer-triggered predictions are preferred), and the result is
    /// only installed if no current prediction appeared in the meantime.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Resolve the project's active entry to an open buffer.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Find where to predict: the next diagnostic's buffer/position.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                let id = prediction.id.clone();
                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Install only if nothing else claimed the slot meanwhile.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                                was_shown: false,
                            }
                        });
                    }
                })?;

                anyhow::Ok(Some(id))
            })
        });
    }
1104
    /// Minimum interval between prediction refreshes triggered by the same
    /// entity; see `queue_prediction_refresh`.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    /// Zero under test so tests don't have to wait out the throttle window.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1109
    /// Enqueues a prediction refresh produced by `do_refresh`, throttling
    /// consecutive refreshes for the same `throttle_entity`.
    ///
    /// At most two refreshes are pending at once: the one in flight plus one
    /// queued behind it. Queueing a third replaces (and cancels) the queued
    /// one. When a refresh completes, any pending refreshes enqueued before it
    /// are canceled.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        ) -> Task<Result<Option<EditPredictionId>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        let task = cx.spawn(async move |this, cx| {
            // Wait out the remainder of the throttle window if the previous
            // refresh was for the same entity.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            // Record this refresh so later refreshes for the same entity are
            // throttled relative to it.
            this.update(cx, |this, cx| {
                this.get_or_init_zeta_project(&project, cx)
                    .last_prediction_refresh = Some((throttle_entity, Instant::now()));
            })
            .ok();

            let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            this.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            edit_prediction_id
        });

        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            // One refresh in flight and one already queued: replace the queued
            // one with this newer request and cancel the old one.
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            self.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1180
1181 pub fn request_prediction(
1182 &mut self,
1183 project: &Entity<Project>,
1184 active_buffer: &Entity<Buffer>,
1185 position: language::Anchor,
1186 cx: &mut Context<Self>,
1187 ) -> Task<Result<Option<EditPrediction>>> {
1188 self.request_prediction_internal(
1189 project.clone(),
1190 active_buffer.clone(),
1191 position,
1192 cx.has_flag::<Zeta2FeatureFlag>(),
1193 cx,
1194 )
1195 }
1196
    /// Requests a prediction for `position` in `active_buffer`, dispatching to
    /// whichever model backend is configured.
    ///
    /// If the model returns no edits and `allow_jump` is true, one follow-up
    /// request is made at the nearest diagnostic outside the range already
    /// covered; the follow-up itself is not allowed to jump again.
    fn request_prediction_internal(
        &mut self,
        project: Entity<Project>,
        active_buffer: Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // Diagnostics within this many rows of the cursor are considered part
        // of the current request; ones outside are candidates for jumping.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        self.get_or_init_zeta_project(&project, cx);
        let zeta_project = self.projects.get(&project.entity_id()).unwrap();
        let events = zeta_project.events(cx);
        let has_events = !events.is_empty();

        let snapshot = active_buffer.read(cx).snapshot();
        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);

        let task = match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
                self,
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                cx,
            ),
            ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
                &project,
                &active_buffer,
                snapshot.clone(),
                position,
                events,
                &zeta_project.recent_paths,
                diagnostic_search_range.clone(),
                cx,
            ),
        };

        cx.spawn(async move |this, cx| {
            // Treat an empty edit list the same as no prediction at all.
            let prediction = task
                .await?
                .filter(|prediction| !prediction.edits.is_empty());

            if prediction.is_none() && allow_jump {
                let cursor_point = position.to_point(&snapshot);
                // Only jump when there were recent edit events; otherwise
                // there is nothing to base a prediction on elsewhere.
                if has_events
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer.clone(),
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_internal(
                                project,
                                jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            Ok(prediction)
        })
    }
1286
    /// Finds the next diagnostic location to jump to for a follow-up
    /// prediction, if any.
    ///
    /// Phase 1 searches the active buffer for the diagnostic closest to the
    /// cursor (by row distance) that lies outside
    /// `active_buffer_diagnostic_search_range`, i.e. one that was not already
    /// covered by the previous request. Phase 2, if phase 1 found nothing,
    /// falls back to another buffer with diagnostics, preferring the one whose
    /// path shares the most leading components with the active buffer's path,
    /// and jumps to its most severe diagnostic.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // Lower severity values are more severe, so `min`
                            // picks the most severe diagnostic in the buffer.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1365
    /// Requests an edit prediction from the Zeta2 model.
    ///
    /// Gathers context (either previously-retrieved agentic context files or
    /// syntax-index-based context, per `options.context`), builds and sends an
    /// OpenAI-style prompt to the prediction endpoint, parses the returned
    /// edits in the configured prompt format, and resolves them against the
    /// buffer they apply to.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        active_snapshot: BufferSnapshot,
        position: language::Anchor,
        events: Vec<Arc<Event>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let file = active_buffer.read(cx).file();
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        // Context files previously gathered by `refresh_context` (agentic mode).
        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Snapshot each context buffer along with its path and ranges; buffers
        // with no file are dropped.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering so prompts are stable across requests.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                // Build the cloud request according to the configured context mode.
                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // existing context ranges, or add the buffer if absent.
                        // Either way the active buffer ends up last in the list.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line-based excerpts for the wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // When debugging is active, report the request and set up a
                // channel through which the response will be forwarded.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Escape hatch for local debugging: build everything but skip
                // the network request.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                // The model sometimes echoes the cursor marker back; it must
                // not end up in the applied edits.
                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path mentioned in the model output to one of the
                // buffers that were included in the request.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output according to the prompt format it was
                // asked to produce.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        inputs,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((
                    id,
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::new(
                    id,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    received_response_at,
                    inputs,
                    cx,
                )
                .await)
            }
        })
    }
1769
    /// Sends a raw OpenAI-style request to the edit prediction endpoint and
    /// returns the parsed response plus any usage information.
    ///
    /// The endpoint is the zed LLM `/predict_edits/raw` URL unless
    /// `PREDICT_EDITS_URL` overrides it. With the `eval-support` feature,
    /// responses are cached keyed by a hash of the URL and serialized request
    /// so evals can replay without hitting the network.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            // Cache hit: skip the network entirely (usage is unknown, so None).
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Cache miss: record the response for future replays.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1827
    /// Post-processes an API response: on success, records edit prediction
    /// usage in the user store; on `ZedUpdateRequiredError`, marks the model
    /// as requiring an update and shows a notification with an update link.
    /// The original error is propagated either way.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1871
    /// Sends an authenticated POST request built by `build` and deserializes
    /// the JSON response body into `Res`.
    ///
    /// Retries exactly once with a freshly refreshed LLM token when the server
    /// signals an expired token. Fails with `ZedUpdateRequiredError` when the
    /// server reports a minimum required version newer than `app_version`.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the initial attempt plus one token-refresh retry.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may require a minimum client version regardless of
            // response status.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh it and retry once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1935
    /// Editing after going this long without edits counts as returning from
    /// idle and triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce delay before refreshing context while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1938
1939 // Refresh the related excerpts when the user just beguns editing after
1940 // an idle period, and after they pause editing.
1941 fn refresh_context_if_needed(
1942 &mut self,
1943 project: &Entity<Project>,
1944 buffer: &Entity<language::Buffer>,
1945 cursor_position: language::Anchor,
1946 cx: &mut Context<Self>,
1947 ) {
1948 if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1949 return;
1950 }
1951
1952 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1953 return;
1954 };
1955
1956 let now = Instant::now();
1957 let was_idle = zeta_project
1958 .refresh_context_timestamp
1959 .map_or(true, |timestamp| {
1960 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1961 });
1962 zeta_project.refresh_context_timestamp = Some(now);
1963 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1964 let buffer = buffer.clone();
1965 let project = project.clone();
1966 async move |this, cx| {
1967 if was_idle {
1968 log::debug!("refetching edit prediction context after idle");
1969 } else {
1970 cx.background_executor()
1971 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1972 .await;
1973 log::debug!("refetching edit prediction context after pause");
1974 }
1975 this.update(cx, |this, cx| {
1976 let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1977
1978 if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1979 zeta_project.refresh_context_task = Some(task.log_err());
1980 };
1981 })
1982 .ok()
1983 }
1984 }));
1985 }
1986
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Asks the context-retrieval model to plan search queries for the excerpt
    // around `cursor_position`, runs those searches over the project, and
    // stores the resulting excerpts in the project's `context` for use by
    // later prediction requests.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval only applies to the agentic context mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Computed once on first use and cached for all subsequent calls
        // (items declared in a function body are still program-wide statics).
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect the search queries from the model's tool calls; unknown
            // tools are logged and skipped rather than failing the refresh.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the retrieved excerpts and clear the in-flight task marker.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2198
2199 pub fn set_context(
2200 &mut self,
2201 project: Entity<Project>,
2202 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2203 ) {
2204 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2205 zeta_project.context = Some(context);
2206 }
2207 }
2208
2209 fn gather_nearby_diagnostics(
2210 cursor_offset: usize,
2211 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2212 snapshot: &BufferSnapshot,
2213 max_diagnostics_bytes: usize,
2214 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2215 // TODO: Could make this more efficient
2216 let mut diagnostic_groups = Vec::new();
2217 for (language_server_id, diagnostics) in diagnostic_sets {
2218 let mut groups = Vec::new();
2219 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2220 diagnostic_groups.extend(
2221 groups
2222 .into_iter()
2223 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2224 );
2225 }
2226
2227 // sort by proximity to cursor
2228 diagnostic_groups.sort_by_key(|group| {
2229 let range = &group.entries[group.primary_ix].range;
2230 if range.start >= cursor_offset {
2231 range.start - cursor_offset
2232 } else if cursor_offset >= range.end {
2233 cursor_offset - range.end
2234 } else {
2235 (cursor_offset - range.start).min(range.end - cursor_offset)
2236 }
2237 });
2238
2239 let mut results = Vec::new();
2240 let mut diagnostic_groups_truncated = false;
2241 let mut diagnostics_byte_count = 0;
2242 for group in diagnostic_groups {
2243 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2244 diagnostics_byte_count += raw_value.get().len();
2245 if diagnostics_byte_count > max_diagnostics_bytes {
2246 diagnostic_groups_truncated = true;
2247 break;
2248 }
2249 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2250 }
2251
2252 (results, diagnostic_groups_truncated)
2253 }
2254
2255 // TODO: Dedupe with similar code in request_prediction?
2256 pub fn cloud_request_for_zeta_cli(
2257 &mut self,
2258 project: &Entity<Project>,
2259 buffer: &Entity<Buffer>,
2260 position: language::Anchor,
2261 cx: &mut Context<Self>,
2262 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
2263 let project_state = self.projects.get(&project.entity_id());
2264
2265 let index_state = project_state.and_then(|state| {
2266 state
2267 .syntax_index
2268 .as_ref()
2269 .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
2270 });
2271 let options = self.options.clone();
2272 let snapshot = buffer.read(cx).snapshot();
2273 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
2274 return Task::ready(Err(anyhow!("No file path for excerpt")));
2275 };
2276 let worktree_snapshots = project
2277 .read(cx)
2278 .worktrees(cx)
2279 .map(|worktree| worktree.read(cx).snapshot())
2280 .collect::<Vec<_>>();
2281
2282 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
2283 let mut path = f.worktree.read(cx).absolutize(&f.path);
2284 if path.pop() { Some(path) } else { None }
2285 });
2286
2287 cx.background_spawn(async move {
2288 let index_state = if let Some(index_state) = index_state {
2289 Some(index_state.lock_owned().await)
2290 } else {
2291 None
2292 };
2293
2294 let cursor_point = position.to_point(&snapshot);
2295
2296 let debug_info = true;
2297 EditPredictionContext::gather_context(
2298 cursor_point,
2299 &snapshot,
2300 parent_abs_path.as_deref(),
2301 match &options.context {
2302 ContextMode::Agentic(_) => {
2303 // TODO
2304 panic!("Llm mode not supported in zeta cli yet");
2305 }
2306 ContextMode::Syntax(edit_prediction_context_options) => {
2307 edit_prediction_context_options
2308 }
2309 },
2310 index_state.as_deref(),
2311 )
2312 .context("Failed to select excerpt")
2313 .map(|context| {
2314 make_syntax_context_cloud_request(
2315 excerpt_path.into(),
2316 context,
2317 // TODO pass everything
2318 Vec::new(),
2319 false,
2320 Vec::new(),
2321 false,
2322 None,
2323 debug_info,
2324 &worktree_snapshots,
2325 index_state.as_deref(),
2326 Some(options.max_prompt_bytes),
2327 options.prompt_format,
2328 )
2329 })
2330 })
2331 }
2332
2333 pub fn wait_for_initial_indexing(
2334 &mut self,
2335 project: &Entity<Project>,
2336 cx: &mut Context<Self>,
2337 ) -> Task<Result<()>> {
2338 let zeta_project = self.get_or_init_zeta_project(project, cx);
2339 if let Some(syntax_index) = &zeta_project.syntax_index {
2340 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2341 } else {
2342 Task::ready(Ok(()))
2343 }
2344 }
2345
2346 fn is_file_open_source(
2347 &self,
2348 project: &Entity<Project>,
2349 file: &Arc<dyn File>,
2350 cx: &App,
2351 ) -> bool {
2352 if !file.is_local() || file.is_private() {
2353 return false;
2354 }
2355 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2356 return false;
2357 };
2358 zeta_project
2359 .license_detection_watchers
2360 .get(&file.worktree_id(cx))
2361 .as_ref()
2362 .is_some_and(|watcher| watcher.is_project_open_source())
2363 }
2364
2365 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2366 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2367 }
2368
2369 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2370 if !self.data_collection_choice.is_enabled() {
2371 return false;
2372 }
2373 events.iter().all(|event| {
2374 matches!(
2375 event.as_ref(),
2376 Event::BufferChange {
2377 in_open_source_repo: true,
2378 ..
2379 }
2380 )
2381 })
2382 }
2383
2384 fn load_data_collection_choice() -> DataCollectionChoice {
2385 let choice = KEY_VALUE_STORE
2386 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2387 .log_err()
2388 .flatten();
2389
2390 match choice.as_deref() {
2391 Some("true") => DataCollectionChoice::Enabled,
2392 Some("false") => DataCollectionChoice::Disabled,
2393 Some(_) => {
2394 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2395 DataCollectionChoice::NotAnswered
2396 }
2397 None => DataCollectionChoice::NotAnswered,
2398 }
2399 }
2400
    /// Iterator over the predictions that have been shown to the user.
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }

    /// Number of predictions that have been shown to the user.
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }

    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2412
    /// Records the user's rating for `prediction`: marks it rated, emits a
    /// telemetry event with the prediction's inputs/output diff, and flushes
    /// telemetry so feedback is delivered promptly.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush immediately rather than waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2431}
2432
2433pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2434 let choice = res.choices.pop()?;
2435 let output_text = match choice.message {
2436 open_ai::RequestMessage::Assistant {
2437 content: Some(open_ai::MessageContent::Plain(content)),
2438 ..
2439 } => content,
2440 open_ai::RequestMessage::Assistant {
2441 content: Some(open_ai::MessageContent::Multipart(mut content)),
2442 ..
2443 } => {
2444 if content.is_empty() {
2445 log::error!("No output from Baseten completion response");
2446 return None;
2447 }
2448
2449 match content.remove(0) {
2450 open_ai::MessagePart::Text { text } => text,
2451 open_ai::MessagePart::Image { .. } => {
2452 log::error!("Expected text, got an image");
2453 return None;
2454 }
2455 }
2456 }
2457 _ => {
2458 log::error!("Invalid response message: {:?}", choice.message);
2459 return None;
2460 }
2461 };
2462 Some(output_text)
2463}
2464
/// Error indicating the installed Zed version is below the minimum the server
/// requires for edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2472
2473fn make_syntax_context_cloud_request(
2474 excerpt_path: Arc<Path>,
2475 context: EditPredictionContext,
2476 events: Vec<Arc<predict_edits_v3::Event>>,
2477 can_collect_data: bool,
2478 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2479 diagnostic_groups_truncated: bool,
2480 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2481 debug_info: bool,
2482 worktrees: &Vec<worktree::Snapshot>,
2483 index_state: Option<&SyntaxIndexState>,
2484 prompt_max_bytes: Option<usize>,
2485 prompt_format: PromptFormat,
2486) -> predict_edits_v3::PredictEditsRequest {
2487 let mut signatures = Vec::new();
2488 let mut declaration_to_signature_index = HashMap::default();
2489 let mut referenced_declarations = Vec::new();
2490
2491 for snippet in context.declarations {
2492 let project_entry_id = snippet.declaration.project_entry_id();
2493 let Some(path) = worktrees.iter().find_map(|worktree| {
2494 worktree.entry_for_id(project_entry_id).map(|entry| {
2495 let mut full_path = RelPathBuf::new();
2496 full_path.push(worktree.root_name());
2497 full_path.push(&entry.path);
2498 full_path
2499 })
2500 }) else {
2501 continue;
2502 };
2503
2504 let parent_index = index_state.and_then(|index_state| {
2505 snippet.declaration.parent().and_then(|parent| {
2506 add_signature(
2507 parent,
2508 &mut declaration_to_signature_index,
2509 &mut signatures,
2510 index_state,
2511 )
2512 })
2513 });
2514
2515 let (text, text_is_truncated) = snippet.declaration.item_text();
2516 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2517 path: path.as_std_path().into(),
2518 text: text.into(),
2519 range: snippet.declaration.item_line_range(),
2520 text_is_truncated,
2521 signature_range: snippet.declaration.signature_range_in_item_text(),
2522 parent_index,
2523 signature_score: snippet.score(DeclarationStyle::Signature),
2524 declaration_score: snippet.score(DeclarationStyle::Declaration),
2525 score_components: snippet.components,
2526 });
2527 }
2528
2529 let excerpt_parent = index_state.and_then(|index_state| {
2530 context
2531 .excerpt
2532 .parent_declarations
2533 .last()
2534 .and_then(|(parent, _)| {
2535 add_signature(
2536 *parent,
2537 &mut declaration_to_signature_index,
2538 &mut signatures,
2539 index_state,
2540 )
2541 })
2542 });
2543
2544 predict_edits_v3::PredictEditsRequest {
2545 excerpt_path,
2546 excerpt: context.excerpt_text.body,
2547 excerpt_line_range: context.excerpt.line_range,
2548 excerpt_range: context.excerpt.range,
2549 cursor_point: predict_edits_v3::Point {
2550 line: predict_edits_v3::Line(context.cursor_point.row),
2551 column: context.cursor_point.column,
2552 },
2553 referenced_declarations,
2554 included_files: vec![],
2555 signatures,
2556 excerpt_parent,
2557 events,
2558 can_collect_data,
2559 diagnostic_groups,
2560 diagnostic_groups_truncated,
2561 git_info,
2562 debug_info,
2563 prompt_max_bytes,
2564 prompt_format,
2565 }
2566}
2567
2568fn add_signature(
2569 declaration_id: DeclarationId,
2570 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2571 signatures: &mut Vec<Signature>,
2572 index: &SyntaxIndexState,
2573) -> Option<usize> {
2574 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2575 return Some(*signature_index);
2576 }
2577 let Some(parent_declaration) = index.declaration(declaration_id) else {
2578 log::error!("bug: missing parent declaration");
2579 return None;
2580 };
2581 let parent_index = parent_declaration.parent().and_then(|parent| {
2582 add_signature(parent, declaration_to_signature_index, signatures, index)
2583 });
2584 let (text, text_is_truncated) = parent_declaration.signature_text();
2585 let signature_index = signatures.len();
2586 signatures.push(Signature {
2587 text: text.into(),
2588 text_is_truncated,
2589 parent_index,
2590 range: parent_declaration.signature_line_range(),
2591 });
2592 declaration_to_signature_index.insert(declaration_id, signature_index);
2593 Some(signature_index)
2594}
2595
/// Cache key for eval runs: the request kind paired with a `u64`
/// discriminator (presumably a hash of the request input — confirm with the
/// producer of these keys).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// The kind of request an eval cache entry corresponds to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2606
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the lowercase name of the entry kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
2617
/// Read/write store for caching eval request/response pairs, keyed by
/// [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`; `input` is the raw request input
    /// (presumably recorded alongside for inspection — confirm with impls).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2623
/// The user's stance on contributing edit-prediction data.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    /// No explicit decision has been made yet.
    NotAnswered,
    /// The user opted in.
    Enabled,
    /// The user opted out.
    Disabled,
}

impl DataCollectionChoice {
    /// True only when the user explicitly opted in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// True once the user has made an explicit choice either way.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Returns the opposite choice; an unanswered state toggles to `Enabled`.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Disabled | Self::NotAnswered => Self::Enabled,
            Self::Enabled => Self::Disabled,
        }
    }
}

impl From<bool> for DataCollectionChoice {
    fn from(value: bool) -> Self {
        if value {
            DataCollectionChoice::Enabled
        } else {
            DataCollectionChoice::Disabled
        }
    }
}
2664
2665struct ZedPredictUpsell;
2666
2667impl Dismissable for ZedPredictUpsell {
2668 const KEY: &'static str = "dismissed-edit-predict-upsell";
2669
2670 fn dismissed() -> bool {
2671 // To make this backwards compatible with older versions of Zed, we
2672 // check if the user has seen the previous Edit Prediction Onboarding
2673 // before, by checking the data collection choice which was written to
2674 // the database once the user clicked on "Accept and Enable"
2675 if KEY_VALUE_STORE
2676 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2677 .log_err()
2678 .is_some_and(|s| s.is_some())
2679 {
2680 return true;
2681 }
2682
2683 KEY_VALUE_STORE
2684 .read_kvp(Self::KEY)
2685 .log_err()
2686 .is_some_and(|s| s.is_some())
2687 }
2688}
2689
/// Whether the edit-prediction upsell modal should be shown, i.e. the user
/// has neither dismissed it nor completed the older onboarding flow.
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2693
/// Registers edit-prediction workspace actions and sets up their command
/// palette feature gating.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rating UI is only reachable behind the feature flag.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured edit prediction provider
        // in the settings file.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2729
/// Hides or shows zeta-related command palette actions depending on the
/// "disable AI" setting and the rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Start with everything hidden; the observers below reveal actions when
    // appropriate.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // React to settings changes: disabling AI hides all zeta actions;
    // otherwise the rate-completions actions follow the feature flag.
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // React to the feature flag itself changing at runtime (only when AI is
    // not disabled).
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2781
2782#[cfg(test)]
2783mod tests {
2784 use std::{path::Path, sync::Arc};
2785
2786 use client::UserStore;
2787 use clock::FakeSystemClock;
2788 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2789 use futures::{
2790 AsyncReadExt, StreamExt,
2791 channel::{mpsc, oneshot},
2792 };
2793 use gpui::{
2794 Entity, TestAppContext,
2795 http_client::{FakeHttpClient, Response},
2796 prelude::*,
2797 };
2798 use indoc::indoc;
2799 use language::OffsetRangeExt as _;
2800 use open_ai::Usage;
2801 use pretty_assertions::{assert_eq, assert_matches};
2802 use project::{FakeFs, Project};
2803 use serde_json::json;
2804 use settings::SettingsStore;
2805 use util::path;
2806 use uuid::Uuid;
2807
2808 use crate::{BufferEditPrediction, Zeta};
2809
    /// Verifies `current_prediction_for_buffer` state transitions: a
    /// prediction editing the cursor's own file is reported as `Local`, a
    /// prediction editing a different file is reported as `Jump`, and opening
    /// the other file makes it `Local` there. Also exercises a context
    /// refresh round-trip in between.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Model responds with a diff that edits the file the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        // The context request is answered with a single search tool call.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.discard_current_prediction(&project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // This time the model edits 2.txt while the cursor is in 1.txt.
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once the target buffer is open, the same prediction is `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2956
    /// A single prediction round-trip: request a prediction, answer with a
    /// unified diff, and check the resulting edit's anchor and text.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff collapses to a single insertion at the cursor position.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3020
    /// Verifies that buffer edits made before a prediction request are
    /// included in the prompt as a unified diff of recent events.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so zeta tracks subsequent edits as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The edit above should appear in the prompt as a diff hunk.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3094
3095 // Skipped until we start including diagnostics in prompt
3096 // #[gpui::test]
3097 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3098 // let (zeta, mut req_rx) = init_test(cx);
3099 // let fs = FakeFs::new(cx.executor());
3100 // fs.insert_tree(
3101 // "/root",
3102 // json!({
3103 // "foo.md": "Hello!\nBye"
3104 // }),
3105 // )
3106 // .await;
3107 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3108
3109 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3110 // let diagnostic = lsp::Diagnostic {
3111 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3112 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3113 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3114 // ..Default::default()
3115 // };
3116
3117 // project.update(cx, |project, cx| {
3118 // project.lsp_store().update(cx, |lsp_store, cx| {
3119 // // Create some diagnostics
3120 // lsp_store
3121 // .update_diagnostics(
3122 // LanguageServerId(0),
3123 // lsp::PublishDiagnosticsParams {
3124 // uri: path_to_buffer_uri.clone(),
3125 // diagnostics: vec![diagnostic],
3126 // version: None,
3127 // },
3128 // None,
3129 // language::DiagnosticSourceKind::Pushed,
3130 // &[],
3131 // cx,
3132 // )
3133 // .unwrap();
3134 // });
3135 // });
3136
3137 // let buffer = project
3138 // .update(cx, |project, cx| {
3139 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3140 // project.open_buffer(path, cx)
3141 // })
3142 // .await
3143 // .unwrap();
3144
3145 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3146 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3147
3148 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3149 // zeta.request_prediction(&project, &buffer, position, cx)
3150 // });
3151
3152 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3153
3154 // assert_eq!(request.diagnostic_groups.len(), 1);
3155 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3156 // .unwrap();
3157 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3158 // assert_eq!(
3159 // value,
3160 // json!({
3161 // "entries": [{
3162 // "range": {
3163 // "start": 8,
3164 // "end": 10
3165 // },
3166 // "diagnostic": {
3167 // "source": null,
3168 // "code": null,
3169 // "code_description": null,
3170 // "severity": 1,
3171 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3172 // "markdown": null,
3173 // "group_id": 0,
3174 // "is_primary": true,
3175 // "is_disk_based": false,
3176 // "is_unnecessary": false,
3177 // "source_kind": "Pushed",
3178 // "data": null,
3179 // "underline": true
3180 // }
3181 // }],
3182 // "primary_ix": 0
3183 // })
3184 // );
3185 // }
3186
    /// Builds a minimal OpenAI-style response whose single assistant message
    /// carries `text` as plain content (used to stand in for a model reply).
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }
3208
3209 fn prompt_from_request(request: &open_ai::Request) -> &str {
3210 assert_eq!(request.messages.len(), 1);
3211 let open_ai::RequestMessage::User {
3212 content: open_ai::MessageContent::Plain(content),
3213 ..
3214 } = &request.messages[0]
3215 else {
3216 panic!(
3217 "Request does not have single user message of type Plain. {:#?}",
3218 request
3219 );
3220 };
3221 content
3222 }
3223
    /// Sets up a test `Zeta` backed by a fake HTTP client.
    ///
    /// Requests to `/predict_edits/raw` are forwarded on the returned channel
    /// together with a oneshot sender the test uses to supply the model's
    /// response; `/client/llm_tokens` is answered with a static token.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                // Hand the parsed request to the test and
                                // block until it sends a response back.
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
3278}