1use anyhow::{Context as _, Result, anyhow, bail};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejection,
7 MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
8 RejectEditPredictionsBody, ZED_VERSION_HEADER_NAME,
9};
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
12use collections::{HashMap, HashSet};
13use command_palette_hooks::CommandPaletteFilter;
14use db::kvp::{Dismissable, KEY_VALUE_STORE};
15use edit_prediction_context::{
16 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
17 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
18 SyntaxIndex, SyntaxIndexState,
19};
20use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
21use futures::channel::{mpsc, oneshot};
22use futures::{AsyncReadExt as _, StreamExt as _};
23use gpui::{
24 App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
25 http_client::{self, AsyncBody, Method},
26 prelude::*,
27};
28use language::{
29 Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
30};
31use language::{BufferSnapshot, OffsetRangeExt};
32use language_model::{LlmApiToken, RefreshLlmTokenListener};
33use lsp::DiagnosticSeverity;
34use open_ai::FunctionDefinition;
35use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
36use release_channel::AppVersion;
37use semver::Version;
38use serde::de::DeserializeOwned;
39use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
40use std::any::{Any as _, TypeId};
41use std::collections::{VecDeque, hash_map};
42use telemetry_events::EditPredictionRating;
43use workspace::Workspace;
44
45use std::fmt::Write as _;
46use std::ops::Range;
47use std::path::Path;
48use std::rc::Rc;
49use std::str::FromStr as _;
50use std::sync::{Arc, LazyLock};
51use std::time::{Duration, Instant};
52use std::{env, mem};
53use thiserror::Error;
54use util::rel_path::RelPathBuf;
55use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
56use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
57
58pub mod assemble_excerpts;
59mod license_detection;
60mod onboarding_modal;
61mod prediction;
62mod provider;
63mod rate_prediction_modal;
64pub mod retrieval_search;
65mod sweep_ai;
66pub mod udiff;
67mod xml_edits;
68pub mod zeta1;
69
70#[cfg(test)]
71mod zeta_tests;
72
73use crate::assemble_excerpts::assemble_excerpts;
74use crate::license_detection::LicenseDetectionWatcher;
75use crate::onboarding_modal::ZedPredictModal;
76pub use crate::prediction::EditPrediction;
77pub use crate::prediction::EditPredictionId;
78pub use crate::prediction::EditPredictionInputs;
79use crate::rate_prediction_modal::{
80 NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
81 ThumbsUpActivePrediction,
82};
83use crate::zeta1::request_prediction_with_zeta1;
84pub use provider::ZetaEditPredictionProvider;
85
// Actions exposed under the `edit_prediction` namespace for the command
// palette and keymaps.
actions!(
    edit_prediction,
    [
        /// Resets the edit prediction onboarding state.
        ResetOnboarding,
        /// Opens the rate completions modal.
        RateCompletions,
        /// Clears the edit prediction history.
        ClearHistory,
    ]
);
97
/// Maximum number of events to track.
const EVENT_COUNT_MAX: usize = 6;
/// Consecutive edits within this many rows of the previous edit are coalesced
/// into a single event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
/// Key-value store key under which the user's data collection choice persists.
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
102
103pub struct SweepFeatureFlag;
104
105impl FeatureFlag for SweepFeatureFlag {
106 const NAME: &str = "sweep-ai";
107}
/// Default sizing of the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to spend half of the excerpt budget before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering strategy: agentic retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for agentic context retrieval.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // Declaration retrieval disabled by default.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options for [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
139
/// Set `ZED_ZETA2_OLLAMA` to any non-empty value to route model requests to a
/// local Ollama server instead of the hosted backend.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
142static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
143 env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
144 "qwen3-coder:30b".to_string()
145 } else {
146 "yqvev8r3".to_string()
147 })
148});
/// Model id used for edit prediction itself. `ZED_ZETA2_MODEL` overrides;
/// the sentinel value `zeta2-exp` maps to a fine-tuned model.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Optional override of the prediction endpoint. Defaults to the local Ollama
/// chat-completions URL when `ZED_ZETA2_OLLAMA` is set, otherwise `None`
/// (meaning: use the Zed LLM service).
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
167
/// Feature flag gating the zeta2 edit prediction backend.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Deliberately opt-in, even for staff accounts.
    fn enabled_for_staff() -> bool {
        false
    }
}
177
/// Newtype holding the app-global [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
182
/// App-global coordinator for edit predictions: tracks per-project state and
/// buffered edit events, and talks to the prediction backends.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when the server reported this Zed version is too old to use.
    update_required: bool,
    /// When present, debug events are streamed to a debug consumer.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    /// Read from `SWEEP_AI_TOKEN`; Sweep is unusable without it.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
    data_collection_choice: DataCollectionChoice,
    /// Rejections batched locally until flushed to the server.
    rejected_predictions: Vec<EditPredictionRejection>,
    /// Signals the background loop (spawned in `new`) to flush rejections.
    reject_predictions_tx: mpsc::UnboundedSender<()>,
    reject_predictions_debounce_task: Option<Task<()>>,
    /// Predictions shown to the user, newest first; capped at 50.
    shown_predictions: VecDeque<EditPrediction>,
    rated_predictions: HashSet<EditPredictionId>,
}
204
/// Which backend serves edit predictions.
#[derive(Default, PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    #[default]
    Zeta1,
    Zeta2,
    Sweep,
}
212
/// Tunable options for zeta2 prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    /// Upper bound on the prompt size sent to the model, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostic text included in the prompt, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Presumably groups buffer changes closer in time than this into one
    // event — confirm at the use site (not visible in this file section).
    pub buffer_change_grouping_interval: Duration,
}
222
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval (search-tool based).
    Agentic(AgenticContextOptions),
    /// Retrieval backed by the [`SyntaxIndex`].
    Syntax(EditPredictionContextOptions),
}
228
/// Options specific to agentic context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
233
234impl ContextMode {
235 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
236 match self {
237 ContextMode::Agentic(options) => &options.excerpt,
238 ContextMode::Syntax(options) => &options.excerpt,
239 }
240 }
241}
242
/// Events emitted on the debug channel (see `Zeta::debug_info`), covering the
/// lifecycle of context retrieval and prediction requests.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
251
/// Debug payload emitted when context retrieval starts.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the search/retrieval model.
    pub search_prompt: String,
}
258
/// Debug payload for context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
264
/// Debug payload emitted when an edit prediction request is issued.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub inputs: EditPredictionInputs,
    /// How long the preceding context retrieval took.
    pub retrieval_time: Duration,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally-assembled prompt, or an error message describing why it
    /// could not be built.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and its latency.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
}
274
/// Debug payload emitted when the model produces search queries.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
281
/// Alias for the wire-format debug info attached to prediction requests.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
283
/// Per-project prediction state owned by [`Zeta`].
struct ZetaProject {
    /// Only built when `ContextMode::Syntax` is active.
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Finalized edit events, oldest first (see `EVENT_COUNT_MAX`).
    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
    /// The still-open, coalescing edit event; finalized lazily.
    last_event: Option<LastEvent>,
    /// Recently active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// At most two in-flight prediction requests.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Entity and time of the last refresh, used for throttling.
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Retrieved context ranges per buffer, when retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
    _subscription: gpui::Subscription,
}
301
302impl ZetaProject {
303 pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
304 self.events
305 .iter()
306 .cloned()
307 .chain(
308 self.last_event
309 .as_ref()
310 .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
311 )
312 .collect()
313 }
314}
315
/// The prediction currently offered to the user for a project.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// What triggered the request (a buffer edit or a diagnostics update).
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
    /// Whether the prediction has been rendered to the user at least once.
    pub was_shown: bool,
}
322
impl CurrentEditPrediction {
    /// Decides whether `self` (a newly-arrived prediction) should replace
    /// `old_prediction` as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // content, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // Predictions for a different buffer always replace.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, always replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Single-edit case: only replace when the new edit targets the
            // same range and merely extends the old text (e.g. streaming).
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
361
/// What triggered a prediction request.
#[derive(Debug, Clone)]
enum PredictionRequestedBy {
    DiagnosticsUpdate,
    /// An edit in the buffer with this entity id.
    Buffer(EntityId),
}
367
368impl PredictionRequestedBy {
369 pub fn buffer_id(&self) -> Option<EntityId> {
370 match self {
371 PredictionRequestedBy::DiagnosticsUpdate => None,
372 PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
373 }
374 }
375}
376
/// An in-flight prediction request; the task resolves with the resulting
/// prediction id, if any.
struct PendingPrediction {
    id: usize,
    task: Task<Option<EditPredictionId>>,
}
381
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}
388
/// State tracked per registered buffer: the last-seen snapshot plus the edit
/// and release subscriptions that keep it up to date.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
393
/// The still-open edit event, kept unfinalized so that subsequent nearby
/// edits can be coalesced into it.
struct LastEvent {
    old_snapshot: BufferSnapshot,
    new_snapshot: BufferSnapshot,
    /// Anchor at the end of the most recent edit, used for proximity checks.
    end_edit_anchor: Option<Anchor>,
}
399
impl LastEvent {
    /// Converts the pending event into a wire-format `Event` by diffing the
    /// old and new snapshots. Returns `None` when nothing changed (same path
    /// and empty diff).
    pub fn finalize(
        &self,
        license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
        cx: &App,
    ) -> Option<Arc<predict_edits_v3::Event>> {
        let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
        let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);

        let file = self.new_snapshot.file();
        let old_file = self.old_snapshot.file();

        // Both the old and new file must belong to worktrees detected as open
        // source; untitled buffers (no file) are treated as not open source.
        let in_open_source_repo = [file, old_file].iter().all(|file| {
            file.is_some_and(|file| {
                license_detection_watchers
                    .get(&file.worktree_id(cx))
                    .is_some_and(|watcher| watcher.is_project_open_source())
            })
        });

        let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());

        if path == old_path && diff.is_empty() {
            None
        } else {
            Some(Arc::new(predict_edits_v3::Event::BufferChange {
                old_path,
                path,
                diff,
                in_open_source_repo,
                // TODO: Actually detect if this edit was predicted or not
                predicted: false,
            }))
        }
    }
}
436
437fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
438 if let Some(file) = snapshot.file() {
439 file.full_path(cx).into()
440 } else {
441 Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
442 }
443}
444
445impl Zeta {
446 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
447 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
448 }
449
450 pub fn global(
451 client: &Arc<Client>,
452 user_store: &Entity<UserStore>,
453 cx: &mut App,
454 ) -> Entity<Self> {
455 cx.try_global::<ZetaGlobal>()
456 .map(|global| global.0.clone())
457 .unwrap_or_else(|| {
458 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
459 cx.set_global(ZetaGlobal(zeta.clone()));
460 zeta
461 })
462 }
463
    /// Creates a new `Zeta`. Also spawns a detached background loop that
    /// flushes batched prediction rejections each time the rejection channel
    /// is signalled (see `discard_prediction`).
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let data_collection_choice = Self::load_data_collection_choice();

        let (reject_tx, mut reject_rx) = mpsc::unbounded();
        cx.spawn(async move |this, cx| {
            while let Some(()) = reject_rx.next().await {
                this.update(cx, |this, cx| this.reject_edit_predictions(cx))?
                    .await
                    .log_err();
            }
            anyhow::Ok(())
        })
        .detach();

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener reports a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::Zeta2,
            // Absence of the token is logged but not fatal; Sweep just stays
            // unavailable (see `has_sweep_api_token`).
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            data_collection_choice,
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
            rejected_predictions: Vec::new(),
            reject_predictions_debounce_task: None,
            reject_predictions_tx: reject_tx,
            rated_predictions: Default::default(),
            shown_predictions: Default::default(),
        }
    }
514
    /// Selects which backend serves subsequent predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
518
    /// Whether a Sweep API token was found in the environment at startup.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }
522
    /// Installs a cache used by the eval harness to memoize model responses.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
527
    /// Starts streaming debug events; replaces any previous subscriber.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }
533
    /// Current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
541
542 pub fn clear_history(&mut self) {
543 for zeta_project in self.projects.values_mut() {
544 zeta_project.events.clear();
545 }
546 }
547
    /// Iterates over the retrieved context ranges for `project`. Yields
    /// nothing when the project is unknown or context has not been gathered.
    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                // `?` inside the closure handles the `context: None` case.
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }
566
567 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
568 if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
569 self.user_store.read(cx).edit_prediction_usage()
570 } else {
571 None
572 }
573 }
574
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project` (snapshot + edit events).
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
588
    /// Returns the state for `project`, creating it on first use. A syntax
    /// index is only built when the current options use `ContextMode::Syntax`.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                last_event: None,
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                license_detection_watchers: HashMap::default(),
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
620
    /// Reacts to project events: keeps the recently-active-paths list ordered
    /// (most recent first) and refreshes predictions on diagnostics updates.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move-to-front with de-duplication.
                    // NOTE(review): the deque is never truncated here — confirm
                    // its growth is bounded elsewhere.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                self.refresh_prediction_from_diagnostics(project, cx);
            }
            _ => (),
        }
    }
651
    /// Registers `buffer` in `zeta_project` (idempotent) and returns its
    /// entry. Also ensures a license-detection watcher exists for the
    /// buffer's worktree. Subscriptions keep the snapshot current and clean
    /// up state when the buffer or worktree is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();

        // Lazily create a license watcher for the buffer's worktree, removing
        // it again when the worktree is released.
        if let Some(file) = buffer.read(cx).file() {
            let worktree_id = file.worktree_id(cx);
            if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
                zeta_project
                    .license_detection_watchers
                    .entry(worktree_id)
                    .or_insert_with(|| {
                        let project_entity_id = project.entity_id();
                        cx.observe_release(&worktree, move |this, _worktree, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.license_detection_watchers.remove(&worktree_id);
                        })
                        .detach();
                        Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
                    });
            }
        }

        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record every edit as a (possibly coalesced) event.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
711
    /// Records an edit to `buffer` as an event. Consecutive edits to the same
    /// buffer within `CHANGE_GROUPING_LINE_SPAN` rows are coalesced into the
    /// open `last_event`; otherwise the open event is finalized and a new one
    /// started. The finalized queue is capped at `EVENT_COUNT_MAX`.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let project_state = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // No-op if nothing actually changed since the recorded snapshot.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the last edit, for row-proximity comparison.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut project_state.events;

        if let Some(LastEvent {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = project_state.last_event.as_mut()
        {
            // Coalesce only when this edit directly continues the open event:
            // same buffer, same version chain, and close enough by row.
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                // Extend the open event in place; its old_snapshot stays.
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Make room for the event about to be finalized below.
        if events.len() + 1 >= EVENT_COUNT_MAX {
            events.pop_front();
        }

        if let Some(event) = project_state.last_event.take() {
            // `finalize` may return None (empty diff); `extend` handles that.
            events.extend(event.finalize(&project_state.license_detection_watchers, cx));
        }

        project_state.last_event = Some(LastEvent {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
        });
    }
774
    /// Returns the project's current prediction from `buffer`'s perspective:
    /// `Local` when it targets this buffer, `Jump` when the user should be
    /// offered a jump to another buffer, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by,
            prediction,
            ..
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else {
            // Only the requesting buffer shows a jump; diagnostics-triggered
            // predictions may offer a jump from any buffer.
            let show_jump = match requested_by {
                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
                    requested_by_buffer_id == &buffer.entity_id()
                }
                PredictionRequestedBy::DiagnosticsUpdate => true,
            };

            if show_jump {
                Some(BufferEditPrediction::Jump { prediction })
            } else {
                None
            }
        }
    }
806
    /// Accepts the project's current prediction: cancels any pending
    /// requests and reports the acceptance to the server. No-op for Sweep.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return,
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        // Accepting one prediction invalidates anything still in flight.
        for pending_prediction in mem::take(&mut project_state.pending_predictions) {
            self.cancel_pending_prediction(pending_prediction, cx);
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint for testing.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Surfaces update-required errors and other API failures.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
859
    /// Uploads the batch of locally-recorded rejections to the server, then
    /// drains the uploaded prefix from the local list. No-op for Sweep or
    /// when there is nothing to upload.
    fn reject_edit_predictions(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
            ZetaEditPredictionModel::Sweep => return Task::ready(anyhow::Ok(())),
        }

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        // Remember the last rejection so we only drain what was uploaded;
        // new rejections may arrive while the request is in flight.
        let last_rejection = self.rejected_predictions.last().cloned();
        let Some(last_rejection) = last_rejection else {
            return Task::ready(anyhow::Ok(()));
        };

        // NOTE(review): `.ok()` silently maps a serialization failure to an
        // empty request body — confirm that is intended.
        let body = serde_json::to_string(&RejectEditPredictionsBody {
            rejections: self.rejected_predictions.clone(),
        })
        .ok();

        cx.spawn(async move |this, cx| {
            let url = client
                .http_client()
                .build_zed_llm_url("/predict_edits/reject", &[])?;

            cx.background_spawn(Self::send_api_request::<()>(
                move |builder| {
                    let req = builder.uri(url.as_ref()).body(body.clone().into());
                    Ok(req?)
                },
                client,
                llm_token,
                app_version,
            ))
            .await
            .context("Failed to reject edit predictions")?;

            // Drain everything up to and including the snapshot's last entry.
            this.update(cx, |this, _| {
                if let Some(ix) = this
                    .rejected_predictions
                    .iter()
                    .position(|rejection| rejection.request_id == last_rejection.request_id)
                {
                    this.rejected_predictions.drain(..ix + 1);
                }
            })
        })
    }
907
908 fn discard_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
909 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
910 project_state.pending_predictions.clear();
911 if let Some(prediction) = project_state.current_prediction.take() {
912 self.discard_prediction(prediction.prediction.id, prediction.was_shown, cx);
913 }
914 };
915 }
916
917 fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
918 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
919 if let Some(current_prediction) = project_state.current_prediction.as_mut() {
920 if !current_prediction.was_shown {
921 current_prediction.was_shown = true;
922 self.shown_predictions
923 .push_front(current_prediction.prediction.clone());
924 if self.shown_predictions.len() > 50 {
925 let completion = self.shown_predictions.pop_back().unwrap();
926 self.rated_predictions.remove(&completion.id);
927 }
928 }
929 }
930 }
931 }
932
    /// Records `prediction_id` as rejected and schedules a batched upload.
    /// The flush is debounced; once the batch reaches the server's per-request
    /// limit it flushes immediately instead.
    fn discard_prediction(
        &mut self,
        prediction_id: EditPredictionId,
        was_shown: bool,
        cx: &mut Context<Self>,
    ) {
        self.rejected_predictions.push(EditPredictionRejection {
            request_id: prediction_id.to_string(),
            was_shown,
        });

        let reached_request_limit =
            self.rejected_predictions.len() >= MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST;
        let reject_tx = self.reject_predictions_tx.clone();
        // Replacing the stored task cancels any previously scheduled flush,
        // restarting the debounce window.
        self.reject_predictions_debounce_task = Some(cx.spawn(async move |_this, cx| {
            const DISCARD_COMPLETIONS_DEBOUNCE: Duration = Duration::from_secs(15);
            if !reached_request_limit {
                cx.background_executor()
                    .timer(DISCARD_COMPLETIONS_DEBOUNCE)
                    .await;
            }
            reject_tx.unbounded_send(()).log_err();
        }));
    }
957
    /// Awaits a cancelled pending request in the background; if it still
    /// produced a prediction, records that prediction as discarded (never
    /// shown).
    fn cancel_pending_prediction(
        &self,
        pending_prediction: PendingPrediction,
        cx: &mut Context<Self>,
    ) {
        cx.spawn(async move |this, cx| {
            let Some(prediction_id) = pending_prediction.task.await else {
                return;
            };

            this.update(cx, |this, cx| {
                this.discard_prediction(prediction_id, false, cx);
            })
            .ok();
        })
        .detach()
    }
975
976 fn is_refreshing(&self, project: &Entity<Project>) -> bool {
977 self.projects
978 .get(&project.entity_id())
979 .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
980 }
981
982 pub fn refresh_prediction_from_buffer(
983 &mut self,
984 project: Entity<Project>,
985 buffer: Entity<Buffer>,
986 position: language::Anchor,
987 cx: &mut Context<Self>,
988 ) {
989 self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
990 let Some(request_task) = this
991 .update(cx, |this, cx| {
992 this.request_prediction(&project, &buffer, position, cx)
993 })
994 .log_err()
995 else {
996 return Task::ready(anyhow::Ok(None));
997 };
998
999 let project = project.clone();
1000 cx.spawn(async move |cx| {
1001 if let Some(prediction) = request_task.await? {
1002 let id = prediction.id.clone();
1003 this.update(cx, |this, cx| {
1004 let project_state = this
1005 .projects
1006 .get_mut(&project.entity_id())
1007 .context("Project not found")?;
1008
1009 let new_prediction = CurrentEditPrediction {
1010 requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
1011 prediction: prediction,
1012 was_shown: false,
1013 };
1014
1015 if project_state
1016 .current_prediction
1017 .as_ref()
1018 .is_none_or(|old_prediction| {
1019 new_prediction.should_replace_prediction(&old_prediction, cx)
1020 })
1021 {
1022 project_state.current_prediction = Some(new_prediction);
1023 cx.notify();
1024 }
1025 anyhow::Ok(())
1026 })??;
1027 Ok(Some(id))
1028 } else {
1029 Ok(None)
1030 }
1031 })
1032 })
1033 }
1034
    /// After a diagnostics update, finds the next diagnostic location starting
    /// from the active buffer and requests a prediction there. Does nothing if
    /// a buffer-triggered prediction is already installed (those take
    /// priority), and never overwrites a prediction that appears meanwhile.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Start from the project's active entry, if it can be opened.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(None));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(None);
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(None);
                };

                let id = prediction.id.clone();
                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Install only if no prediction appeared in the
                        // meantime; buffer-triggered ones win.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                                was_shown: false,
                            }
                        });
                    }
                })?;

                anyhow::Ok(Some(id))
            })
        });
    }
1107
    /// Minimum interval between prediction refreshes for the same entity;
    /// zero under `cfg(test)` so tests run without artificial delays.
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
1112
    /// Schedules a prediction refresh, throttling back-to-back requests for
    /// the same `throttle_entity` and keeping at most two refreshes pending.
    ///
    /// When a refresh completes, any refreshes that were enqueued before it
    /// are cancelled, since their results would be stale.
    fn queue_prediction_refresh(
        &mut self,
        project: Entity<Project>,
        throttle_entity: EntityId,
        cx: &mut Context<Self>,
        do_refresh: impl FnOnce(
            WeakEntity<Self>,
            &mut AsyncApp,
        ) -> Task<Result<Option<EditPredictionId>>>
        + 'static,
    ) {
        let zeta_project = self.get_or_init_zeta_project(&project, cx);
        let pending_prediction_id = zeta_project.next_pending_prediction_id;
        zeta_project.next_pending_prediction_id += 1;
        let last_request = zeta_project.last_prediction_refresh;

        // TODO report cancelled requests like in zeta1
        let task = cx.spawn(async move |this, cx| {
            // Throttle: if the last refresh targeted the same entity within
            // THROTTLE_TIMEOUT, sleep out the remaining time first.
            if let Some((last_entity, last_timestamp)) = last_request
                && throttle_entity == last_entity
                && let Some(timeout) =
                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let edit_prediction_id = do_refresh(this.clone(), cx).await.log_err().flatten();

            // When a prediction completes, remove it from the pending list, and cancel
            // any pending predictions that were enqueued before it.
            this.update(cx, |this, cx| {
                let zeta_project = this.get_or_init_zeta_project(&project, cx);
                let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
                for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
                    if pending_prediction.id == pending_prediction_id {
                        pending_predictions.remove(ix);
                        // Everything enqueued earlier is now stale.
                        for pending_prediction in pending_predictions.drain(0..ix) {
                            this.cancel_pending_prediction(pending_prediction, cx)
                        }
                        break;
                    }
                }
                this.get_or_init_zeta_project(&project, cx)
                    .pending_predictions = pending_predictions;
                cx.notify();
            })
            .ok();

            edit_prediction_id
        });

        // Invariant: at most two pending predictions at any time. A third
        // request replaces the second, which gets cancelled.
        if zeta_project.pending_predictions.len() <= 1 {
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
        } else if zeta_project.pending_predictions.len() == 2 {
            let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
            zeta_project.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                task,
            });
            self.cancel_pending_prediction(pending_prediction, cx);
        }
    }
1178
1179 pub fn request_prediction(
1180 &mut self,
1181 project: &Entity<Project>,
1182 active_buffer: &Entity<Buffer>,
1183 position: language::Anchor,
1184 cx: &mut Context<Self>,
1185 ) -> Task<Result<Option<EditPrediction>>> {
1186 match self.edit_prediction_model {
1187 ZetaEditPredictionModel::Zeta1 => {
1188 request_prediction_with_zeta1(self, project, active_buffer, position, cx)
1189 }
1190 ZetaEditPredictionModel::Zeta2 => {
1191 self.request_prediction_with_zeta2(project, active_buffer, position, cx)
1192 }
1193 ZetaEditPredictionModel::Sweep => {
1194 self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
1195 }
1196 }
1197 }
1198
    /// Requests an edit prediction from the Sweep autocomplete API.
    ///
    /// Builds a brotli-compressed JSON request containing the active buffer's
    /// text, recent edit events, the heads of recently-used buffers, and
    /// diagnostics near the cursor. When the response contains no edits and
    /// `allow_jump` is set, retries once at the nearest diagnostic location
    /// outside the already-searched range (with `allow_jump = false` to avoid
    /// recursion).
    fn request_prediction_with_sweep(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        allow_jump: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let snapshot = active_buffer.read(cx).snapshot();
        let debug_info = self.sweep_ai_debug_info.clone();
        // Without an API token there is nothing to do: predict nothing.
        let Some(api_token) = self.sweep_api_token.clone() else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let project_file = project::File::from_dyn(snapshot.file());
        let repo_name = project_file
            .map(|file| file.worktree.read(cx).root_name_str())
            .unwrap_or("untitled")
            .into();
        let offset = position.to_offset(&snapshot);

        let project_state = self.get_or_init_zeta_project(project, cx);
        let events = project_state.events(cx);
        let has_events = !events.is_empty();
        let recent_buffers = project_state.recent_paths.iter().cloned();
        let http_client = cx.http_client();

        // Up to three recently-used open buffers (excluding the active one)
        // are sent along as additional context.
        let recent_buffer_snapshots = recent_buffers
            .filter_map(|project_path| {
                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
                if active_buffer == &buffer {
                    None
                } else {
                    Some(buffer.read(cx).snapshot())
                }
            })
            .take(3)
            .collect::<Vec<_>>();

        // Diagnostics within this many rows of the cursor are included in the
        // request; anything further away is a candidate "jump" target later.
        const DIAGNOSTIC_LINES_RANGE: u32 = 20;

        let cursor_point = position.to_point(&snapshot);
        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
        let diagnostic_search_range =
            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
        let buffer_snapshotted_at = Instant::now();

        // Request construction, network round-trip, and diff computation all
        // happen off the main thread.
        let result = cx.background_spawn({
            let snapshot = snapshot.clone();
            let diagnostic_search_range = diagnostic_search_range.clone();
            async move {
                let text = snapshot.text();

                // Serialize recent edit events into Sweep's textual format.
                let mut recent_changes = String::new();
                for event in &events {
                    sweep_ai::write_event(event.as_ref(), &mut recent_changes).unwrap();
                }

                // First ~30 lines of each recent buffer, as context chunks.
                let mut file_chunks = recent_buffer_snapshots
                    .into_iter()
                    .map(|snapshot| {
                        let end_point = Point::new(30, 0).min(snapshot.max_point());
                        sweep_ai::FileChunk {
                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
                            file_path: snapshot
                                .file()
                                .map(|f| f.path().as_unix_str())
                                .unwrap_or("untitled")
                                .to_string(),
                            start_line: 0,
                            end_line: end_point.row as usize,
                            timestamp: snapshot.file().and_then(|file| {
                                Some(
                                    file.disk_state()
                                        .mtime()?
                                        .to_seconds_and_nanos_for_persistence()?
                                        .0,
                                )
                            }),
                        }
                    })
                    .collect::<Vec<_>>();

                let diagnostic_entries =
                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
                let mut diagnostic_content = String::new();
                let mut diagnostic_count = 0;

                for entry in diagnostic_entries {
                    let start_point: Point = entry.range.start;

                    // Only the four mapped severities are forwarded.
                    let severity = match entry.diagnostic.severity {
                        DiagnosticSeverity::ERROR => "error",
                        DiagnosticSeverity::WARNING => "warning",
                        DiagnosticSeverity::INFORMATION => "info",
                        DiagnosticSeverity::HINT => "hint",
                        _ => continue,
                    };

                    diagnostic_count += 1;

                    writeln!(
                        &mut diagnostic_content,
                        "{} at line {}: {}",
                        severity,
                        start_point.row + 1,
                        entry.diagnostic.message
                    )?;
                }

                // Diagnostics ride along as a pseudo file chunk.
                if !diagnostic_content.is_empty() {
                    file_chunks.push(sweep_ai::FileChunk {
                        file_path: format!("Diagnostics for {}", full_path.display()),
                        start_line: 0,
                        end_line: diagnostic_count,
                        content: diagnostic_content,
                        timestamp: None,
                    });
                }

                let request_body = sweep_ai::AutocompleteRequest {
                    debug_info,
                    repo_name,
                    file_path: full_path.clone(),
                    file_contents: text.clone(),
                    original_file_contents: text,
                    cursor_position: offset,
                    recent_changes: recent_changes.clone(),
                    changes_above_cursor: true,
                    multiple_suggestions: false,
                    branch: None,
                    file_chunks,
                    retrieval_chunks: vec![],
                    recent_user_actions: vec![],
                    // TODO
                    privacy_mode_enabled: false,
                };

                // The request body is brotli-compressed JSON (see the
                // `Content-Encoding: br` header below).
                let mut buf: Vec<u8> = Vec::new();
                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
                serde_json::to_writer(writer, &request_body)?;
                let body: AsyncBody = buf.into();

                // Mirror of the request inputs, retained for debugging/rating.
                let inputs = EditPredictionInputs {
                    events,
                    included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
                        path: full_path.clone(),
                        max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
                        excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
                            start_line: cloud_llm_client::predict_edits_v3::Line(0),
                            text: request_body.file_contents.into(),
                        }],
                    }],
                    cursor_point: cloud_llm_client::predict_edits_v3::Point {
                        column: cursor_point.column,
                        line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
                    },
                    cursor_path: full_path.clone(),
                };

                const SWEEP_API_URL: &str =
                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";

                let request = http_client::Request::builder()
                    .uri(SWEEP_API_URL)
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", api_token))
                    .header("Connection", "keep-alive")
                    .header("Content-Encoding", "br")
                    .method(Method::POST)
                    .body(body)?;

                let mut response = http_client.send(request).await?;

                let mut body: Vec<u8> = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;

                let response_received_at = Instant::now();
                if !response.status().is_success() {
                    anyhow::bail!(
                        "Request failed with status: {:?}\nBody: {}",
                        response.status(),
                        String::from_utf8_lossy(&body),
                    );
                };

                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;

                // Convert the returned completion into anchored edits by
                // diffing it against the span of text it replaces.
                let old_text = snapshot
                    .text_for_range(response.start_index..response.end_index)
                    .collect::<String>();
                let edits = language::text_diff(&old_text, &response.completion)
                    .into_iter()
                    .map(|(range, text)| {
                        (
                            snapshot.anchor_after(response.start_index + range.start)
                                ..snapshot.anchor_before(response.start_index + range.end),
                            text,
                        )
                    })
                    .collect::<Vec<_>>();

                anyhow::Ok((
                    response.autocomplete_id,
                    edits,
                    snapshot,
                    response_received_at,
                    inputs,
                ))
            }
        });

        // Two owned handles are needed: `active_buffer` is moved into
        // `next_diagnostic_location` below, `buffer` outlives it.
        let buffer = active_buffer.clone();
        let project = project.clone();
        let active_buffer = active_buffer.clone();

        cx.spawn(async move |this, cx| {
            let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;

            if edits.is_empty() {
                // No edits at the cursor: optionally jump to the nearest
                // diagnostic outside the searched range and retry once.
                if has_events
                    && allow_jump
                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                        active_buffer,
                        &snapshot,
                        diagnostic_search_range,
                        cursor_point,
                        &project,
                        cx,
                    )
                    .await?
                {
                    return this
                        .update(cx, |this, cx| {
                            this.request_prediction_with_sweep(
                                &project,
                                &jump_buffer,
                                jump_position,
                                false,
                                cx,
                            )
                        })?
                        .await;
                }

                return anyhow::Ok(None);
            }

            anyhow::Ok(
                EditPrediction::new(
                    EditPredictionId(id.into()),
                    &buffer,
                    &old_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    response_received_at,
                    inputs,
                    cx,
                )
                .await,
            )
        })
    }
1468
    /// Finds a diagnostic location for a prediction to "jump" to: the closest
    /// diagnostic in the active buffer outside the range already covered by
    /// the previous request, or failing that, the most severe diagnostic in
    /// another buffer whose path shares the most leading directories with the
    /// active buffer's path.
    ///
    /// Returns `None` when no suitable diagnostic exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                // Skip diagnostics the previous request already saw.
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            // Fall back to another file with diagnostics, preferring the one
            // "closest" to the active file in the directory tree.
            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                // Lower severity values are more severe (ERROR < WARNING),
                // so `min_by_key` selects the most severe diagnostic.
                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1547
    /// Requests an edit prediction from the Zeta2 service.
    ///
    /// Gathers context for the cursor position (either agentic context files
    /// or syntax-index based, depending on `options.context`), collects
    /// nearby diagnostics and recent events, builds a prompt, sends it to the
    /// LLM endpoint, and parses the model's textual output (diff or XML
    /// edits, per prompt format) back into anchored buffer edits.
    fn request_prediction_with_zeta2(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Snapshot of the syntax index, if this project has one.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let buffer_snapshotted_at = Instant::now();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| state.events(cx))
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the file's parent directory, when it has one.
        let file = active_buffer.read(cx).file();
        let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = file
            .as_ref()
            .map_or(false, |file| self.can_collect_file(project, file, cx));

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // Snapshot all context buffers plus their relevant ranges.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering so prompts are stable across runs.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = Instant::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // entry (or append a new entry), and move the active
                        // buffer to the end of the list.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Project anchor ranges to line ranges and assemble
                        // the request's excerpt payloads.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                // Request inputs retained for debugging and rating.
                let inputs = EditPredictionInputs {
                    included_files: cloud_request.included_files,
                    events: cloud_request.events,
                    cursor_point: cloud_request.cursor_point,
                    cursor_path: cloud_request.excerpt_path,
                };

                let retrieval_time = Instant::now() - before_retrieval;

                // Forward the request to the debug view, if one is listening.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                inputs: inputs.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch: skip the network round-trip.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), Duration::ZERO))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let generation_params =
                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: generation_params.stop.unwrap_or_default(),
                    temperature: generation_params.temperature.unwrap_or(0.7),
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = Instant::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let received_response_at = Instant::now();
                let request_time = received_response_at - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                // The model sometimes echoes the cursor marker; drop it.
                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Lets the diff/edit parsers resolve paths mentioned in the
                // model output to the snapshots we sent.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal
                    | PromptFormat::MinimalQwen
                    | PromptFormat::SeedCoder1120 => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        inputs,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((
                    id,
                    inputs,
                    edited_buffer,
                    edited_buffer_snapshot,
                    edits,
                    received_response_at,
                )) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::new(
                    id,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    received_response_at,
                    inputs,
                    cx,
                )
                .await)
            }
        })
    }
1954
1955 async fn send_raw_llm_request(
1956 request: open_ai::Request,
1957 client: Arc<Client>,
1958 llm_token: LlmApiToken,
1959 app_version: Version,
1960 #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1961 #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1962 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1963 let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1964 http_client::Url::parse(&predict_edits_url)?
1965 } else {
1966 client
1967 .http_client()
1968 .build_zed_llm_url("/predict_edits/raw", &[])?
1969 };
1970
1971 #[cfg(feature = "eval-support")]
1972 let cache_key = if let Some(cache) = eval_cache {
1973 use collections::FxHasher;
1974 use std::hash::{Hash, Hasher};
1975
1976 let mut hasher = FxHasher::default();
1977 url.hash(&mut hasher);
1978 let request_str = serde_json::to_string_pretty(&request)?;
1979 request_str.hash(&mut hasher);
1980 let hash = hasher.finish();
1981
1982 let key = (eval_cache_kind, hash);
1983 if let Some(response_str) = cache.read(key) {
1984 return Ok((serde_json::from_str(&response_str)?, None));
1985 }
1986
1987 Some((cache, request_str, key))
1988 } else {
1989 None
1990 };
1991
1992 let (response, usage) = Self::send_api_request(
1993 |builder| {
1994 let req = builder
1995 .uri(url.as_ref())
1996 .body(serde_json::to_string(&request)?.into());
1997 Ok(req?)
1998 },
1999 client,
2000 llm_token,
2001 app_version,
2002 )
2003 .await?;
2004
2005 #[cfg(feature = "eval-support")]
2006 if let Some((cache, request, key)) = cache_key {
2007 cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
2008 }
2009
2010 Ok((response, usage))
2011 }
2012
2013 fn handle_api_response<T>(
2014 this: &WeakEntity<Self>,
2015 response: Result<(T, Option<EditPredictionUsage>)>,
2016 cx: &mut gpui::AsyncApp,
2017 ) -> Result<T> {
2018 match response {
2019 Ok((data, usage)) => {
2020 if let Some(usage) = usage {
2021 this.update(cx, |this, cx| {
2022 this.user_store.update(cx, |user_store, cx| {
2023 user_store.update_edit_prediction_usage(usage, cx);
2024 });
2025 })
2026 .ok();
2027 }
2028 Ok(data)
2029 }
2030 Err(err) => {
2031 if err.is::<ZedUpdateRequiredError>() {
2032 cx.update(|cx| {
2033 this.update(cx, |this, _cx| {
2034 this.update_required = true;
2035 })
2036 .ok();
2037
2038 let error_message: SharedString = err.to_string().into();
2039 show_app_notification(
2040 NotificationId::unique::<ZedUpdateRequiredError>(),
2041 cx,
2042 move |cx| {
2043 cx.new(|cx| {
2044 ErrorMessagePrompt::new(error_message.clone(), cx)
2045 .with_link_button("Update Zed", "https://zed.dev/releases")
2046 })
2047 },
2048 );
2049 })
2050 .ok();
2051 }
2052 Err(err)
2053 }
2054 }
2055 }
2056
    /// POSTs a request produced by `build` with auth and version headers and
    /// deserializes the JSON response.
    ///
    /// Retries exactly once with a freshly refreshed LLM token when the
    /// server marks the current one as expired. Fails with
    /// [`ZedUpdateRequiredError`] when the server's minimum required client
    /// version exceeds `app_version`.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: Version,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: the initial attempt, plus one retry after a
        // token refresh.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server-advertised minimum client version, if any.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
2120
    /// How long without a context refresh before the user counts as idle.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refetching context while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
2123
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
    /// Schedules a context refresh for agentic context mode: immediately when
    /// this is the first edit after an idle period, otherwise debounced until
    /// the user pauses editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        // Idle when no refresh has been recorded recently (or ever).
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task implicitly cancels it.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold the task so it runs to completion even if no one
                    // awaits it.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
2171
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Flow: select an excerpt around the cursor, ask the LLM to plan search
    // queries (via a forced tool call), run those searches over the project,
    // and store the resulting excerpts as the project's prediction context.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op unless the project is registered and agentic context is enabled.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // If no excerpt can be selected around the cursor there is nothing to plan.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the excerpt plus recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project.events(cx),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Schema + description for the search tool, computed once per process.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // Collect search queries from the tool calls in the first choice.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                // Unknown tool names are logged and skipped rather than failing
                // the whole refresh.
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the result on the project (if still registered) and clear
            // the in-flight task handle.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2383
2384 pub fn set_context(
2385 &mut self,
2386 project: Entity<Project>,
2387 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2388 ) {
2389 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2390 zeta_project.context = Some(context);
2391 }
2392 }
2393
    /// Collects diagnostic groups from all language servers, sorts them by
    /// proximity of their primary entry to the cursor, serializes each as raw
    /// JSON, and stops once the accumulated payload would exceed
    /// `max_diagnostics_bytes`.
    ///
    /// Returns the serialized groups plus a flag indicating whether any groups
    /// were dropped due to the byte budget.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                // Cursor before (or at) the range: distance to its start.
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                // Cursor after (or at) the range: distance past its end.
                cursor_offset - range.end
            } else {
                // Cursor strictly inside the range: key is the distance to the
                // nearest endpoint, not 0. NOTE(review): confirm this is
                // intentional — it can rank an enclosing diagnostic below one
                // that merely touches the cursor.
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        // Greedily take the closest groups until the byte budget is exhausted.
        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }
2439
    // TODO: Dedupe with similar code in request_prediction?
    //
    /// Builds a `PredictEditsRequest` for the zeta CLI from the current buffer
    /// and cursor position, gathering syntax-based context on a background
    /// thread. Only `ContextMode::Syntax` is supported; the agentic mode panics
    /// (see TODO below).
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        // Snapshot worktrees up front so path resolution can happen off the
        // main thread.
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, if it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the syntax index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
2517
2518 pub fn wait_for_initial_indexing(
2519 &mut self,
2520 project: &Entity<Project>,
2521 cx: &mut Context<Self>,
2522 ) -> Task<Result<()>> {
2523 let zeta_project = self.get_or_init_zeta_project(project, cx);
2524 if let Some(syntax_index) = &zeta_project.syntax_index {
2525 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2526 } else {
2527 Task::ready(Ok(()))
2528 }
2529 }
2530
2531 fn is_file_open_source(
2532 &self,
2533 project: &Entity<Project>,
2534 file: &Arc<dyn File>,
2535 cx: &App,
2536 ) -> bool {
2537 if !file.is_local() || file.is_private() {
2538 return false;
2539 }
2540 let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
2541 return false;
2542 };
2543 zeta_project
2544 .license_detection_watchers
2545 .get(&file.worktree_id(cx))
2546 .as_ref()
2547 .is_some_and(|watcher| watcher.is_project_open_source())
2548 }
2549
2550 fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
2551 self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
2552 }
2553
2554 fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
2555 if !self.data_collection_choice.is_enabled() {
2556 return false;
2557 }
2558 events.iter().all(|event| {
2559 matches!(
2560 event.as_ref(),
2561 Event::BufferChange {
2562 in_open_source_repo: true,
2563 ..
2564 }
2565 )
2566 })
2567 }
2568
2569 fn load_data_collection_choice() -> DataCollectionChoice {
2570 let choice = KEY_VALUE_STORE
2571 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2572 .log_err()
2573 .flatten();
2574
2575 match choice.as_deref() {
2576 Some("true") => DataCollectionChoice::Enabled,
2577 Some("false") => DataCollectionChoice::Disabled,
2578 Some(_) => {
2579 log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
2580 DataCollectionChoice::NotAnswered
2581 }
2582 None => DataCollectionChoice::NotAnswered,
2583 }
2584 }
2585
    /// Iterates over all predictions that have been shown to the user, in
    /// insertion order (double-ended, so callers can walk newest-first).
    pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
        self.shown_predictions.iter()
    }
2589
    /// Number of predictions shown so far.
    // NOTE(review): named "completions" while the backing field says
    // "predictions" — consider renaming for consistency (would affect callers).
    pub fn shown_completions_len(&self) -> usize {
        self.shown_predictions.len()
    }
2593
    /// Whether the user has already rated the prediction with this id.
    pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
        self.rated_predictions.contains(id)
    }
2597
    /// Records a rating (plus freeform feedback) for a shown prediction,
    /// reports it to telemetry, and flushes telemetry immediately so the
    /// rating isn't lost if the app exits soon after.
    pub fn rate_prediction(
        &mut self,
        prediction: &EditPrediction,
        rating: EditPredictionRating,
        feedback: String,
        cx: &mut Context<Self>,
    ) {
        self.rated_predictions.insert(prediction.id.clone());
        telemetry::event!(
            "Edit Prediction Rated",
            rating,
            inputs = prediction.inputs,
            output = prediction.edit_preview.as_unified_diff(&prediction.edits),
            feedback
        );
        // Flush right away rather than waiting for the periodic flush.
        self.client.telemetry().flush_events().detach();
        cx.notify();
    }
2616}
2617
2618pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2619 let choice = res.choices.pop()?;
2620 let output_text = match choice.message {
2621 open_ai::RequestMessage::Assistant {
2622 content: Some(open_ai::MessageContent::Plain(content)),
2623 ..
2624 } => content,
2625 open_ai::RequestMessage::Assistant {
2626 content: Some(open_ai::MessageContent::Multipart(mut content)),
2627 ..
2628 } => {
2629 if content.is_empty() {
2630 log::error!("No output from Baseten completion response");
2631 return None;
2632 }
2633
2634 match content.remove(0) {
2635 open_ai::MessagePart::Text { text } => text,
2636 open_ai::MessagePart::Image { .. } => {
2637 log::error!("Expected text, got an image");
2638 return None;
2639 }
2640 }
2641 }
2642 _ => {
2643 log::error!("Invalid response message: {:?}", choice.message);
2644 return None;
2645 }
2646 };
2647 Some(output_text)
2648}
2649
/// Error surfaced when the prediction service reports (via response headers)
/// that the running Zed build is older than the minimum supported version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: Version,
}
2657
2658fn make_syntax_context_cloud_request(
2659 excerpt_path: Arc<Path>,
2660 context: EditPredictionContext,
2661 events: Vec<Arc<predict_edits_v3::Event>>,
2662 can_collect_data: bool,
2663 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2664 diagnostic_groups_truncated: bool,
2665 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2666 debug_info: bool,
2667 worktrees: &Vec<worktree::Snapshot>,
2668 index_state: Option<&SyntaxIndexState>,
2669 prompt_max_bytes: Option<usize>,
2670 prompt_format: PromptFormat,
2671) -> predict_edits_v3::PredictEditsRequest {
2672 let mut signatures = Vec::new();
2673 let mut declaration_to_signature_index = HashMap::default();
2674 let mut referenced_declarations = Vec::new();
2675
2676 for snippet in context.declarations {
2677 let project_entry_id = snippet.declaration.project_entry_id();
2678 let Some(path) = worktrees.iter().find_map(|worktree| {
2679 worktree.entry_for_id(project_entry_id).map(|entry| {
2680 let mut full_path = RelPathBuf::new();
2681 full_path.push(worktree.root_name());
2682 full_path.push(&entry.path);
2683 full_path
2684 })
2685 }) else {
2686 continue;
2687 };
2688
2689 let parent_index = index_state.and_then(|index_state| {
2690 snippet.declaration.parent().and_then(|parent| {
2691 add_signature(
2692 parent,
2693 &mut declaration_to_signature_index,
2694 &mut signatures,
2695 index_state,
2696 )
2697 })
2698 });
2699
2700 let (text, text_is_truncated) = snippet.declaration.item_text();
2701 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2702 path: path.as_std_path().into(),
2703 text: text.into(),
2704 range: snippet.declaration.item_line_range(),
2705 text_is_truncated,
2706 signature_range: snippet.declaration.signature_range_in_item_text(),
2707 parent_index,
2708 signature_score: snippet.score(DeclarationStyle::Signature),
2709 declaration_score: snippet.score(DeclarationStyle::Declaration),
2710 score_components: snippet.components,
2711 });
2712 }
2713
2714 let excerpt_parent = index_state.and_then(|index_state| {
2715 context
2716 .excerpt
2717 .parent_declarations
2718 .last()
2719 .and_then(|(parent, _)| {
2720 add_signature(
2721 *parent,
2722 &mut declaration_to_signature_index,
2723 &mut signatures,
2724 index_state,
2725 )
2726 })
2727 });
2728
2729 predict_edits_v3::PredictEditsRequest {
2730 excerpt_path,
2731 excerpt: context.excerpt_text.body,
2732 excerpt_line_range: context.excerpt.line_range,
2733 excerpt_range: context.excerpt.range,
2734 cursor_point: predict_edits_v3::Point {
2735 line: predict_edits_v3::Line(context.cursor_point.row),
2736 column: context.cursor_point.column,
2737 },
2738 referenced_declarations,
2739 included_files: vec![],
2740 signatures,
2741 excerpt_parent,
2742 events,
2743 can_collect_data,
2744 diagnostic_groups,
2745 diagnostic_groups_truncated,
2746 git_info,
2747 debug_info,
2748 prompt_max_bytes,
2749 prompt_format,
2750 }
2751}
2752
/// Ensures the signature for `declaration_id` — and, recursively, those of its
/// parent declarations — is present in `signatures`, returning its index.
/// Results are memoized in `declaration_to_signature_index` so each
/// declaration is serialized at most once. Returns `None` (and logs) when the
/// declaration is missing from the index.
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    // Recurse first so parents are pushed before children and `parent_index`
    // always points at an earlier slot.
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}
2780
/// Cache key for eval runs: the kind of request plus a hash of its contents.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2783
/// Which stage of the prediction pipeline an eval cache entry belongs to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2791
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Lowercase stage name, as used in cache entry identifiers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{name}")
    }
}
2802
/// Pluggable request/response cache used during evals to avoid re-issuing
/// identical LLM and search requests.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` for `key`; `input` is the request that produced it.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2808
/// The user's answer to the edit-prediction data-collection prompt.
#[derive(Debug, Clone, Copy)]
pub enum DataCollectionChoice {
    NotAnswered,
    Enabled,
    Disabled,
}

impl DataCollectionChoice {
    /// Collection is on only with an explicit opt-in.
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }

    /// Whether the user has made any explicit choice.
    pub fn is_answered(self) -> bool {
        !matches!(self, Self::NotAnswered)
    }

    /// Flips the choice; an unanswered state toggles to enabled.
    #[must_use]
    pub fn toggle(&self) -> DataCollectionChoice {
        match self {
            Self::Enabled => Self::Disabled,
            Self::Disabled | Self::NotAnswered => Self::Enabled,
        }
    }
}
2840
2841impl From<bool> for DataCollectionChoice {
2842 fn from(value: bool) -> Self {
2843 match value {
2844 true => DataCollectionChoice::Enabled,
2845 false => DataCollectionChoice::Disabled,
2846 }
2847 }
2848}
2849
2850struct ZedPredictUpsell;
2851
2852impl Dismissable for ZedPredictUpsell {
2853 const KEY: &'static str = "dismissed-edit-predict-upsell";
2854
2855 fn dismissed() -> bool {
2856 // To make this backwards compatible with older versions of Zed, we
2857 // check if the user has seen the previous Edit Prediction Onboarding
2858 // before, by checking the data collection choice which was written to
2859 // the database once the user clicked on "Accept and Enable"
2860 if KEY_VALUE_STORE
2861 .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
2862 .log_err()
2863 .is_some_and(|s| s.is_some())
2864 {
2865 return true;
2866 }
2867
2868 KEY_VALUE_STORE
2869 .read_kvp(Self::KEY)
2870 .log_err()
2871 .is_some_and(|s| s.is_some())
2872 }
2873}
2874
/// Whether the edit-prediction upsell modal should still be shown (i.e. the
/// user has neither dismissed it nor completed the legacy onboarding).
pub fn should_show_upsell_modal() -> bool {
    !ZedPredictUpsell::dismissed()
}
2878
/// Registers this module's workspace actions and sets up command-palette
/// feature gating for them.
pub fn init(cx: &mut App) {
    feature_gate_predict_edits_actions(cx);

    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
        // Rate-completions UI is behind a feature flag even when invoked.
        workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
            if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
                RatePredictionsModal::toggle(workspace, window, cx);
            }
        });

        workspace.register_action(
            move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
                ZedPredictModal::toggle(
                    workspace,
                    workspace.user_store().clone(),
                    workspace.client().clone(),
                    window,
                    cx,
                )
            },
        );

        // Resetting onboarding clears the configured provider in settings.
        workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
            update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
                settings
                    .project
                    .all_languages
                    .features
                    .get_or_insert_default()
                    .edit_prediction_provider = Some(EditPredictionProvider::None)
            });
        });
    })
    .detach();
}
2914
/// Hides this module's actions from the command palette by default, then keeps
/// their visibility in sync with the "disable AI" setting and the
/// rate-completions feature flag.
fn feature_gate_predict_edits_actions(cx: &mut App) {
    let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
    let zeta_all_action_types = [
        TypeId::of::<RateCompletions>(),
        TypeId::of::<ResetOnboarding>(),
        zed_actions::OpenZedPredictOnboarding.type_id(),
        TypeId::of::<ClearHistory>(),
        TypeId::of::<ThumbsUpActivePrediction>(),
        TypeId::of::<ThumbsDownActivePrediction>(),
        TypeId::of::<NextEdit>(),
        TypeId::of::<PreviousEdit>(),
    ];

    // Everything starts hidden; the observers below selectively re-show.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_action_types(&rate_completion_action_types);
        filter.hide_action_types(&reset_onboarding_action_types);
        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
    });

    // Re-evaluate on every settings change: disabling AI hides all zeta
    // actions; otherwise rate-completions visibility follows the feature flag.
    cx.observe_global::<SettingsStore>(move |cx| {
        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
        let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();

        CommandPaletteFilter::update_global(cx, |filter, _cx| {
            if is_ai_disabled {
                filter.hide_action_types(&zeta_all_action_types);
            } else if has_feature_flag {
                filter.show_action_types(&rate_completion_action_types);
            } else {
                filter.hide_action_types(&rate_completion_action_types);
            }
        });
    })
    .detach();

    // Also react to feature-flag changes directly (flags can arrive after
    // startup), respecting the AI-disabled setting.
    cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
        if !DisableAiSettings::get_global(cx).disable_ai {
            if is_enabled {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.show_action_types(&rate_completion_action_types);
                });
            } else {
                CommandPaletteFilter::update_global(cx, |filter, _cx| {
                    filter.hide_action_types(&rate_completion_action_types);
                });
            }
        }
    })
    .detach();
}
2966
2967#[cfg(test)]
2968mod tests {
2969 use std::{path::Path, sync::Arc};
2970
2971 use client::UserStore;
2972 use clock::FakeSystemClock;
2973 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2974 use futures::{
2975 AsyncReadExt, StreamExt,
2976 channel::{mpsc, oneshot},
2977 };
2978 use gpui::{
2979 Entity, TestAppContext,
2980 http_client::{FakeHttpClient, Response},
2981 prelude::*,
2982 };
2983 use indoc::indoc;
2984 use language::OffsetRangeExt as _;
2985 use open_ai::Usage;
2986 use pretty_assertions::{assert_eq, assert_matches};
2987 use project::{FakeFs, Project};
2988 use serde_json::json;
2989 use settings::SettingsStore;
2990 use util::path;
2991 use uuid::Uuid;
2992
2993 use crate::{BufferEditPrediction, Zeta};
2994
    // End-to-end check of `current_prediction_for_buffer` state transitions:
    // a local prediction in the edited buffer, an agentic context refresh, and
    // a cross-file prediction surfaced first as `Jump` then as `Local` once
    // the target buffer is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Model proposes an edit within the buffer being edited -> `Local`.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        // Respond to the search-planning request with a single tool call that
        // targets the other file, so it becomes part of the context.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.discard_current_prediction(&project, cx);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // Viewed from buffer1, an edit in 2.txt is surfaced as a `Jump`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt itself is open, the same prediction is `Local` there.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
3141
    // Happy path for `request_prediction`: a unified-diff model response is
    // converted into a single buffer edit at the expected position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The "How" -> "How are you?" diff should surface as one insertion
        // at the cursor (line 1, column 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3205
    // Verifies that edits made to a registered buffer are reported as events
    // in the prediction request prompt (as a unified diff), and that the
    // response is still applied correctly.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register first so the subsequent edit is captured as an event.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The recorded edit must appear in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
3279
3280 // Skipped until we start including diagnostics in prompt
3281 // #[gpui::test]
3282 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
3283 // let (zeta, mut req_rx) = init_test(cx);
3284 // let fs = FakeFs::new(cx.executor());
3285 // fs.insert_tree(
3286 // "/root",
3287 // json!({
3288 // "foo.md": "Hello!\nBye"
3289 // }),
3290 // )
3291 // .await;
3292 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
3293
3294 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
3295 // let diagnostic = lsp::Diagnostic {
3296 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
3297 // severity: Some(lsp::DiagnosticSeverity::ERROR),
3298 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
3299 // ..Default::default()
3300 // };
3301
3302 // project.update(cx, |project, cx| {
3303 // project.lsp_store().update(cx, |lsp_store, cx| {
3304 // // Create some diagnostics
3305 // lsp_store
3306 // .update_diagnostics(
3307 // LanguageServerId(0),
3308 // lsp::PublishDiagnosticsParams {
3309 // uri: path_to_buffer_uri.clone(),
3310 // diagnostics: vec![diagnostic],
3311 // version: None,
3312 // },
3313 // None,
3314 // language::DiagnosticSourceKind::Pushed,
3315 // &[],
3316 // cx,
3317 // )
3318 // .unwrap();
3319 // });
3320 // });
3321
3322 // let buffer = project
3323 // .update(cx, |project, cx| {
3324 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
3325 // project.open_buffer(path, cx)
3326 // })
3327 // .await
3328 // .unwrap();
3329
3330 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
3331 // let position = snapshot.anchor_before(language::Point::new(0, 0));
3332
3333 // let _prediction_task = zeta.update(cx, |zeta, cx| {
3334 // zeta.request_prediction(&project, &buffer, position, cx)
3335 // });
3336
3337 // let (request, _respond_tx) = req_rx.next().await.unwrap();
3338
3339 // assert_eq!(request.diagnostic_groups.len(), 1);
3340 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
3341 // .unwrap();
3342 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
3343 // assert_eq!(
3344 // value,
3345 // json!({
3346 // "entries": [{
3347 // "range": {
3348 // "start": 8,
3349 // "end": 10
3350 // },
3351 // "diagnostic": {
3352 // "source": null,
3353 // "code": null,
3354 // "code_description": null,
3355 // "severity": 1,
3356 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
3357 // "markdown": null,
3358 // "group_id": 0,
3359 // "is_primary": true,
3360 // "is_disk_based": false,
3361 // "is_unnecessary": false,
3362 // "source_kind": "Pushed",
3363 // "data": null,
3364 // "underline": true
3365 // }
3366 // }],
3367 // "primary_ix": 0
3368 // })
3369 // );
3370 // }
3371
3372 fn model_response(text: &str) -> open_ai::Response {
3373 open_ai::Response {
3374 id: Uuid::new_v4().to_string(),
3375 object: "response".into(),
3376 created: 0,
3377 model: "model".into(),
3378 choices: vec![open_ai::Choice {
3379 index: 0,
3380 message: open_ai::RequestMessage::Assistant {
3381 content: Some(open_ai::MessageContent::Plain(text.to_string())),
3382 tool_calls: vec![],
3383 },
3384 finish_reason: None,
3385 }],
3386 usage: Usage {
3387 prompt_tokens: 0,
3388 completion_tokens: 0,
3389 total_tokens: 0,
3390 },
3391 }
3392 }
3393
3394 fn prompt_from_request(request: &open_ai::Request) -> &str {
3395 assert_eq!(request.messages.len(), 1);
3396 let open_ai::RequestMessage::User {
3397 content: open_ai::MessageContent::Plain(content),
3398 ..
3399 } = &request.messages[0]
3400 else {
3401 panic!(
3402 "Request does not have single user message of type Plain. {:#?}",
3403 request
3404 );
3405 };
3406 content
3407 }
3408
    /// Sets up a test `Zeta` entity backed by a fake HTTP client.
    ///
    /// Returns the `Zeta` entity together with a channel that yields each
    /// request sent to `/predict_edits/raw`, paired with a oneshot sender the
    /// test uses to supply the fake model response for that request.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            // Fake HTTP client that stands in for the cloud backend; it only
            // knows the two endpoints the tests exercise and panics on
            // anything else.
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token endpoint: hand back a static fake token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: deserialize the request
                            // body, forward it to the test through the
                            // channel, and wait for the test to provide the
                            // response via the oneshot sender.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
3463}