1use anyhow::{Context as _, Result, anyhow, bail};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::Project;
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33
34use std::env;
35use std::ops::Range;
36use std::path::Path;
37use std::str::FromStr as _;
38use std::sync::{Arc, LazyLock};
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod assemble_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49pub mod udiff;
50mod xml_edits;
51
52use crate::assemble_excerpts::assemble_excerpts;
53use crate::prediction::EditPrediction;
54pub use crate::prediction::EditPredictionId;
55pub use provider::ZetaEditPredictionProvider;
56
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place about half of the excerpt before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};
65
/// Default context mode: agentic retrieval with the default excerpt options.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-based context gathering. Declaration retrieval is
/// disabled here (`max_retrieved_declarations: 0`).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Options a fresh [`Zeta`] instance starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
91
/// True when the `ZED_ZETA2_OLLAMA` env var is set to a non-empty value;
/// switches default model ids and the endpoint to a local Ollama server.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));

/// Model id used for the context-retrieval step; overridable via
/// `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    // `unwrap_or_else` so the fallback string is only allocated when the env
    // var is absent (`unwrap_or` evaluated the default unconditionally).
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or_else(|_| {
        if *USE_OLLAMA {
            "qwen3-coder:30b".to_string()
        } else {
            "yqvev8r3".to_string()
        }
    })
});

/// Model id used for the edit-prediction step; overridable via
/// `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});

/// Explicit prediction endpoint from `ZED_PREDICT_EDITS_URL`, falling back to
/// the local Ollama endpoint when [`USE_OLLAMA`] is set; `None` means "use the
/// Zed LLM service".
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
119
/// Feature flag gating the zeta2 edit-prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Deliberately opt-in even for staff accounts.
    fn enabled_for_staff() -> bool {
        false
    }
}
129
/// Newtype wrapper storing the app-global [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
134
/// Edit-prediction engine state: per-project tracking, LLM credentials, and
/// user-tunable options.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Keeps `llm_token` refreshed; wired up in `Zeta::new`.
    _llm_token_subscription: Subscription,
    // Keyed by `Entity<Project>::entity_id`.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server reported this Zed build is too old (see
    // `handle_api_response`).
    update_required: bool,
    // When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
}
147
/// Tunable knobs controlling context gathering, prompt size, and event
/// grouping for edit predictions.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    // Upper bound on the assembled prompt, in bytes.
    pub max_prompt_bytes: usize,
    // Upper bound on diagnostics included in a request, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Buffer edits closer together than this are coalesced into one event
    // (see `Zeta::push_event`).
    pub buffer_change_grouping_interval: Duration,
}
157
/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context comes from previously retrieved ranges stored in
    /// `ZetaProject::context`.
    Agentic(AgenticContextOptions),
    /// Context comes from the project's [`SyntaxIndex`] (only built in this
    /// mode — see `get_or_init_zeta_project`).
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
168
169impl ContextMode {
170 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
171 match self {
172 ContextMode::Agentic(options) => &options.excerpt,
173 ContextMode::Syntax(options) => &options.excerpt,
174 }
175 }
176}
177
/// Debug events streamed to the receiver returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Snapshot of a prediction request, plus a channel on which the raw model
/// response (or error string) and the request duration are later delivered.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // The locally built prompt, or the error message if building it failed.
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
218
/// Per-project prediction state.
struct ZetaProject {
    // Only populated in `ContextMode::Syntax` (see `get_or_init_zeta_project`).
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Recent buffer-change events, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Context ranges previously retrieved for this project, per buffer.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}
229
/// The prediction currently offered to the user, plus which buffer asked for it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    // Buffer the request originated from; may differ from `prediction.buffer`
    // when the prediction jumps to another file.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
235
236impl CurrentEditPrediction {
237 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
238 let Some(new_edits) = self
239 .prediction
240 .interpolate(&self.prediction.buffer.read(cx))
241 else {
242 return false;
243 };
244
245 if self.prediction.buffer != old_prediction.prediction.buffer {
246 return true;
247 }
248
249 let Some(old_edits) = old_prediction
250 .prediction
251 .interpolate(&old_prediction.prediction.buffer.read(cx))
252 else {
253 return true;
254 };
255
256 // This reduces the occurrence of UI thrash from replacing edits
257 //
258 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
259 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
260 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
261 && old_edits.len() == 1
262 && new_edits.len() == 1
263 {
264 let (old_range, old_text) = &old_edits[0];
265 let (new_range, new_text) = &new_edits[0];
266 new_range == old_range && new_text.starts_with(old_text.as_ref())
267 } else {
268 true
269 }
270 }
271}
272
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but edits another one.
    Jump { prediction: &'a EditPrediction },
}

/// Tracking state for a buffer registered via `Zeta::register_buffer`.
struct RegisteredBuffer {
    // Last snapshot observed; the baseline for the next change event.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
284
/// A user action recorded for inclusion in prediction requests.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
293
impl Event {
    /// Converts this event into its wire representation, or `None` when the
    /// event carries nothing worth sending.
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // `old_path` is only `Some` when it differs from the current
                // path, i.e. the file was renamed.
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): because `old_path` is `None` whenever it equals
                // `path`, `path == old_path` only holds when both are `None`
                // (an unnamed buffer). A named buffer with an empty diff still
                // produces an event — confirm that's intended.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}
331
332impl Zeta {
333 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
334 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
335 }
336
337 pub fn global(
338 client: &Arc<Client>,
339 user_store: &Entity<UserStore>,
340 cx: &mut App,
341 ) -> Entity<Self> {
342 cx.try_global::<ZetaGlobal>()
343 .map(|global| global.0.clone())
344 .unwrap_or_else(|| {
345 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
346 cx.set_global(ZetaGlobal(zeta.clone()));
347 zeta
348 })
349 }
350
    /// Creates a `Zeta` with [`DEFAULT_OPTIONS`], subscribing to LLM token
    /// refresh events so the cached token stays valid.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    // Re-fetch the token whenever the listener fires.
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
        }
    }
378
    /// Installs a cache used by `send_raw_llm_request` to memoize LLM
    /// responses during evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
383
384 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
385 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
386 self.debug_tx = Some(debug_watch_tx);
387 debug_watch_rx
388 }
389
    /// The currently active prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
393
    /// Replaces the prediction options. Note: existing projects keep their
    /// current syntax index — a context-mode change only affects projects
    /// initialized afterwards (see `get_or_init_zeta_project`).
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
397
398 pub fn clear_history(&mut self) {
399 for zeta_project in self.projects.values_mut() {
400 zeta_project.events.clear();
401 }
402 }
403
404 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
405 self.projects
406 .get(&project.entity_id())
407 .map(|project| project.events.iter())
408 .into_iter()
409 .flatten()
410 }
411
412 pub fn context_for_project(
413 &self,
414 project: &Entity<Project>,
415 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
416 self.projects
417 .get(&project.entity_id())
418 .and_then(|project| {
419 Some(
420 project
421 .context
422 .as_ref()?
423 .iter()
424 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
425 )
426 })
427 .into_iter()
428 .flatten()
429 }
430
    /// The user's edit-prediction usage, as last reported by the server.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
434
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
438
    /// Starts tracking `buffer` within `project` so its edits feed the event
    /// history used for predictions.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
448
449 fn get_or_init_zeta_project(
450 &mut self,
451 project: &Entity<Project>,
452 cx: &mut App,
453 ) -> &mut ZetaProject {
454 self.projects
455 .entry(project.entity_id())
456 .or_insert_with(|| ZetaProject {
457 syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
458 Some(cx.new(|cx| {
459 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
460 }))
461 } else {
462 None
463 },
464 events: VecDeque::new(),
465 registered_buffers: HashMap::default(),
466 current_prediction: None,
467 context: None,
468 refresh_context_task: None,
469 refresh_context_debounce_task: None,
470 refresh_context_timestamp: None,
471 })
472 }
473
    /// Returns the [`RegisteredBuffer`] for `buffer`, creating it — along with
    /// its edit and release subscriptions — on first registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record a change event on every edit, as long as the
                        // project is still alive.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop tracking state when the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
511
    /// Records a `BufferChange` event if `buffer` changed since the snapshot
    /// we last took of it, and returns the up-to-date snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            // Swap the stored snapshot for the new one; the old snapshot
            // becomes the "before" side of the event.
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
539
    /// Appends `event` to the project's history, coalescing consecutive edits
    /// to the same buffer and bounding the history length.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // Merge only when this event directly continues the previous one:
            // same buffer, contiguous versions, and within the interval.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
578
579 fn current_prediction_for_buffer(
580 &self,
581 buffer: &Entity<Buffer>,
582 project: &Entity<Project>,
583 cx: &App,
584 ) -> Option<BufferEditPrediction<'_>> {
585 let project_state = self.projects.get(&project.entity_id())?;
586
587 let CurrentEditPrediction {
588 requested_by_buffer_id,
589 prediction,
590 } = project_state.current_prediction.as_ref()?;
591
592 if prediction.targets_buffer(buffer.read(cx)) {
593 Some(BufferEditPrediction::Local { prediction })
594 } else if *requested_by_buffer_id == buffer.entity_id() {
595 Some(BufferEditPrediction::Jump { prediction })
596 } else {
597 None
598 }
599 }
600
    /// Consumes the project's current prediction and reports the acceptance to
    /// the prediction service in the background.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint.
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Applies usage info and surfaces "update required" errors.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
645
646 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
647 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
648 project_state.current_prediction.take();
649 };
650 }
651
652 pub fn refresh_prediction(
653 &mut self,
654 project: &Entity<Project>,
655 buffer: &Entity<Buffer>,
656 position: language::Anchor,
657 cx: &mut Context<Self>,
658 ) -> Task<Result<()>> {
659 let request_task = self.request_prediction(project, buffer, position, cx);
660 let buffer = buffer.clone();
661 let project = project.clone();
662
663 cx.spawn(async move |this, cx| {
664 if let Some(prediction) = request_task.await? {
665 this.update(cx, |this, cx| {
666 let project_state = this
667 .projects
668 .get_mut(&project.entity_id())
669 .context("Project not found")?;
670
671 let new_prediction = CurrentEditPrediction {
672 requested_by_buffer_id: buffer.entity_id(),
673 prediction: prediction,
674 };
675
676 if project_state
677 .current_prediction
678 .as_ref()
679 .is_none_or(|old_prediction| {
680 new_prediction.should_replace_prediction(&old_prediction, cx)
681 })
682 {
683 project_state.current_prediction = Some(new_prediction);
684 }
685 anyhow::Ok(())
686 })??;
687 }
688 Ok(())
689 })
690 }
691
    /// Requests an edit prediction for `position` in `active_buffer`.
    ///
    /// Captures the necessary state on the main thread, then on a background
    /// task: gathers context (agentic included-files or syntax-based), builds
    /// the prompt, sends the LLM request, parses the returned edits according
    /// to the prompt format, and resolves them to the buffer they apply to.
    /// Returns `None` when no excerpt/context could be selected or the model
    /// produced no usable output.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Snapshot everything the background task needs up front.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Recorded history, converted to its wire representation.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // (buffer, snapshot, full path, context ranges) per context file;
        // buffers without a file are skipped.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering (helps prompt caching and reproducibility).
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // entry (moving it last), or add a new entry for it.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line ranges and assemble
                        // the excerpt payloads for the request.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging is on, announce the request and hand back a
                // channel on which the raw response will be delivered below.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch: build everything but skip the wire
                // request.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path mentioned in the model output to the
                // matching included file's snapshot and ranges.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output into concrete buffer edits, according
                // to the prompt format it was asked to produce.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
1074
    /// Sends a raw OpenAI-style completion request to the prediction endpoint
    /// (or the `PREDICT_EDITS_URL` override), optionally memoized through the
    /// eval cache.
    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        // Cache key hashes the URL and the serialized request, so a change to
        // either produces a fresh entry; a hit returns without any network.
        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        // Record the response for future cache hits.
        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }
1132
1133 fn handle_api_response<T>(
1134 this: &WeakEntity<Self>,
1135 response: Result<(T, Option<EditPredictionUsage>)>,
1136 cx: &mut gpui::AsyncApp,
1137 ) -> Result<T> {
1138 match response {
1139 Ok((data, usage)) => {
1140 if let Some(usage) = usage {
1141 this.update(cx, |this, cx| {
1142 this.user_store.update(cx, |user_store, cx| {
1143 user_store.update_edit_prediction_usage(usage, cx);
1144 });
1145 })
1146 .ok();
1147 }
1148 Ok(data)
1149 }
1150 Err(err) => {
1151 if err.is::<ZedUpdateRequiredError>() {
1152 cx.update(|cx| {
1153 this.update(cx, |this, _cx| {
1154 this.update_required = true;
1155 })
1156 .ok();
1157
1158 let error_message: SharedString = err.to_string().into();
1159 show_app_notification(
1160 NotificationId::unique::<ZedUpdateRequiredError>(),
1161 cx,
1162 move |cx| {
1163 cx.new(|cx| {
1164 ErrorMessagePrompt::new(error_message.clone(), cx)
1165 .with_link_button("Update Zed", "https://zed.dev/releases")
1166 })
1167 },
1168 );
1169 })
1170 .ok();
1171 }
1172 Err(err)
1173 }
1174 }
1175 }
1176
    /// Sends an authenticated POST request built by `build` and deserializes
    /// the JSON response body into `Res`.
    ///
    /// Behavior:
    /// - Attaches the LLM token and Zed version headers to every attempt.
    /// - Fails with [`ZedUpdateRequiredError`] when the server advertises a
    ///   minimum required version newer than `app_version`.
    /// - On an expired-token response, refreshes the token and retries once.
    /// - Any other non-success status fails with the response body included
    ///   in the error.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Check the minimum-version header before the status so outdated
            // clients get a clear "update required" error.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh and loop around for exactly one retry.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1240
    /// Gap without edits after which the next edit counts as "resuming after
    /// idle", triggering an immediate (non-debounced) context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refreshing context while the user is actively
    /// editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1243
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    //
    // Only applies in `ContextMode::Agentic`. An edit that arrives more than
    // `CONTEXT_RETRIEVAL_IDLE_DURATION` after the previous one refreshes
    // immediately; otherwise the refresh waits out
    // `CONTEXT_RETRIEVAL_DEBOUNCE_DURATION`, and each new edit replaces the
    // pending debounce task.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Idle if there was no prior refresh, or the last one is old enough.
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Store the new debounce task, replacing any pending one.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Keep a handle to the refresh so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1291
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Pipeline: build a context-retrieval prompt from the excerpt around the
    // cursor plus recent edit events, ask the model to plan search queries
    // through a search tool call, run those searches over the project, then
    // store the resulting excerpts as the project's context.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Only registered projects in agentic context mode participate.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // No excerpt around the cursor means there is nothing to plan from.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Tool schema derived from `SearchToolInput`, computed once per
        // process. The `.unwrap()` assumes the generated schema always has a
        // top-level description.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool; calls to
            // unrecognized tools are logged and skipped.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight task handle now that we're finishing.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1507
1508 pub fn set_context(
1509 &mut self,
1510 project: Entity<Project>,
1511 context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1512 ) {
1513 if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1514 zeta_project.context = Some(context);
1515 }
1516 }
1517
    /// Collects diagnostic groups from all language servers, orders them by
    /// proximity of each group's primary entry to `cursor_offset`, and keeps
    /// groups until their serialized JSON size would exceed
    /// `max_diagnostics_bytes`.
    ///
    /// Returns the kept groups plus a flag indicating whether any group was
    /// dropped due to the byte budget.
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                // Cursor before the range: distance to its start.
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                // Cursor after the range: distance to its end.
                cursor_offset - range.end
            } else {
                // Cursor inside the range.
                // NOTE(review): this keys on the distance to the nearest
                // endpoint rather than 0 — confirm that's intended, since an
                // enclosing group can rank below an adjacent one.
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            // Serializing a resolved group is treated as infallible here.
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }
1563
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds the cloud `PredictEditsRequest` used by the zeta CLI, gathering
    /// syntax-based context around `position` in `buffer` on a background
    /// task.
    ///
    /// Fails when the buffer has no file path. Panics for
    /// `ContextMode::Agentic`, which this path does not support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the syntax index state for the duration of the gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1641
1642 pub fn wait_for_initial_indexing(
1643 &mut self,
1644 project: &Entity<Project>,
1645 cx: &mut App,
1646 ) -> Task<Result<()>> {
1647 let zeta_project = self.get_or_init_zeta_project(project, cx);
1648 if let Some(syntax_index) = &zeta_project.syntax_index {
1649 syntax_index.read(cx).wait_for_initial_file_indexing(cx)
1650 } else {
1651 Task::ready(Ok(()))
1652 }
1653 }
1654}
1655
1656pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1657 let choice = res.choices.pop()?;
1658 let output_text = match choice.message {
1659 open_ai::RequestMessage::Assistant {
1660 content: Some(open_ai::MessageContent::Plain(content)),
1661 ..
1662 } => content,
1663 open_ai::RequestMessage::Assistant {
1664 content: Some(open_ai::MessageContent::Multipart(mut content)),
1665 ..
1666 } => {
1667 if content.is_empty() {
1668 log::error!("No output from Baseten completion response");
1669 return None;
1670 }
1671
1672 match content.remove(0) {
1673 open_ai::MessagePart::Text { text } => text,
1674 open_ai::MessagePart::Image { .. } => {
1675 log::error!("Expected text, got an image");
1676 return None;
1677 }
1678 }
1679 }
1680 _ => {
1681 log::error!("Invalid response message: {:?}", choice.message);
1682 return None;
1683 }
1684 };
1685 Some(output_text)
1686}
1687
/// Error raised when the server reports a minimum supported client version
/// newer than this build; shown to the user with an update prompt.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Lowest Zed version the server will accept.
    minimum_version: SemanticVersion,
}
1695
1696fn make_syntax_context_cloud_request(
1697 excerpt_path: Arc<Path>,
1698 context: EditPredictionContext,
1699 events: Vec<predict_edits_v3::Event>,
1700 can_collect_data: bool,
1701 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1702 diagnostic_groups_truncated: bool,
1703 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1704 debug_info: bool,
1705 worktrees: &Vec<worktree::Snapshot>,
1706 index_state: Option<&SyntaxIndexState>,
1707 prompt_max_bytes: Option<usize>,
1708 prompt_format: PromptFormat,
1709) -> predict_edits_v3::PredictEditsRequest {
1710 let mut signatures = Vec::new();
1711 let mut declaration_to_signature_index = HashMap::default();
1712 let mut referenced_declarations = Vec::new();
1713
1714 for snippet in context.declarations {
1715 let project_entry_id = snippet.declaration.project_entry_id();
1716 let Some(path) = worktrees.iter().find_map(|worktree| {
1717 worktree.entry_for_id(project_entry_id).map(|entry| {
1718 let mut full_path = RelPathBuf::new();
1719 full_path.push(worktree.root_name());
1720 full_path.push(&entry.path);
1721 full_path
1722 })
1723 }) else {
1724 continue;
1725 };
1726
1727 let parent_index = index_state.and_then(|index_state| {
1728 snippet.declaration.parent().and_then(|parent| {
1729 add_signature(
1730 parent,
1731 &mut declaration_to_signature_index,
1732 &mut signatures,
1733 index_state,
1734 )
1735 })
1736 });
1737
1738 let (text, text_is_truncated) = snippet.declaration.item_text();
1739 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1740 path: path.as_std_path().into(),
1741 text: text.into(),
1742 range: snippet.declaration.item_line_range(),
1743 text_is_truncated,
1744 signature_range: snippet.declaration.signature_range_in_item_text(),
1745 parent_index,
1746 signature_score: snippet.score(DeclarationStyle::Signature),
1747 declaration_score: snippet.score(DeclarationStyle::Declaration),
1748 score_components: snippet.components,
1749 });
1750 }
1751
1752 let excerpt_parent = index_state.and_then(|index_state| {
1753 context
1754 .excerpt
1755 .parent_declarations
1756 .last()
1757 .and_then(|(parent, _)| {
1758 add_signature(
1759 *parent,
1760 &mut declaration_to_signature_index,
1761 &mut signatures,
1762 index_state,
1763 )
1764 })
1765 });
1766
1767 predict_edits_v3::PredictEditsRequest {
1768 excerpt_path,
1769 excerpt: context.excerpt_text.body,
1770 excerpt_line_range: context.excerpt.line_range,
1771 excerpt_range: context.excerpt.range,
1772 cursor_point: predict_edits_v3::Point {
1773 line: predict_edits_v3::Line(context.cursor_point.row),
1774 column: context.cursor_point.column,
1775 },
1776 referenced_declarations,
1777 included_files: vec![],
1778 signatures,
1779 excerpt_parent,
1780 events,
1781 can_collect_data,
1782 diagnostic_groups,
1783 diagnostic_groups_truncated,
1784 git_info,
1785 debug_info,
1786 prompt_max_bytes,
1787 prompt_format,
1788 }
1789}
1790
1791fn add_signature(
1792 declaration_id: DeclarationId,
1793 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1794 signatures: &mut Vec<Signature>,
1795 index: &SyntaxIndexState,
1796) -> Option<usize> {
1797 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1798 return Some(*signature_index);
1799 }
1800 let Some(parent_declaration) = index.declaration(declaration_id) else {
1801 log::error!("bug: missing parent declaration");
1802 return None;
1803 };
1804 let parent_index = parent_declaration.parent().and_then(|parent| {
1805 add_signature(parent, declaration_to_signature_index, signatures, index)
1806 });
1807 let (text, text_is_truncated) = parent_declaration.signature_text();
1808 let signature_index = signatures.len();
1809 signatures.push(Signature {
1810 text: text.into(),
1811 text_is_truncated,
1812 parent_index,
1813 range: parent_declaration.signature_line_range(),
1814 });
1815 declaration_to_signature_index.insert(declaration_id, signature_index);
1816 Some(signature_index)
1817}
1818
/// Cache key for eval responses: the kind of interaction plus a hash of the
/// request (URL + serialized body).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// The kind of LLM interaction an eval-cache entry corresponds to.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    /// Context-retrieval planning requests.
    Context,
    /// Retrieval search execution.
    Search,
    /// Edit-prediction requests.
    Prediction,
}
1829
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Writes the kind's lowercase name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(name)
    }
}
1840
/// Read/write store for cached LLM responses used during evals, letting
/// repeated runs with identical requests skip network calls.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached response for `key`, if present.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores response `value` for `key`; `input` is the serialized request.
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
1846
1847#[cfg(test)]
1848mod tests {
1849 use std::{path::Path, sync::Arc};
1850
1851 use client::UserStore;
1852 use clock::FakeSystemClock;
1853 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1854 use futures::{
1855 AsyncReadExt, StreamExt,
1856 channel::{mpsc, oneshot},
1857 };
1858 use gpui::{
1859 Entity, TestAppContext,
1860 http_client::{FakeHttpClient, Response},
1861 prelude::*,
1862 };
1863 use indoc::indoc;
1864 use language::OffsetRangeExt as _;
1865 use open_ai::Usage;
1866 use pretty_assertions::{assert_eq, assert_matches};
1867 use project::{FakeFs, Project};
1868 use serde_json::json;
1869 use settings::SettingsStore;
1870 use util::path;
1871 use uuid::Uuid;
1872
1873 use crate::{BufferEditPrediction, Zeta};
1874
    // Exercises `current_prediction_for_buffer` across three stages: a
    // prediction in the edited buffer, an agentic context refresh driven by a
    // tool call, and a prediction whose diff targets a different file (which
    // should surface as a jump from buffer1 and as local once that file is
    // open).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // The fake model answers with a diff against the edited file.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Respond with a tool call planning a single search over 2.txt.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1 the prediction for 2.txt appears as a jump.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt itself is open, the same prediction is local to it.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2019
    // Single prediction round-trip through `request_prediction`: send a diff
    // response and check the resulting edit's anchor and inserted text.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff turns "How" into "How are you?": one edit inserting
        // " are you?" at line 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2083
    // Verifies that edits made to a registered buffer are reported to the
    // model: the request prompt must contain the edit history as a unified
    // diff, and the returned prediction must still apply correctly.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the tracked buffer so an edit event is recorded.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt should include the recorded edit as a diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2157
2158 // Skipped until we start including diagnostics in prompt
2159 // #[gpui::test]
2160 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2161 // let (zeta, mut req_rx) = init_test(cx);
2162 // let fs = FakeFs::new(cx.executor());
2163 // fs.insert_tree(
2164 // "/root",
2165 // json!({
2166 // "foo.md": "Hello!\nBye"
2167 // }),
2168 // )
2169 // .await;
2170 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2171
2172 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2173 // let diagnostic = lsp::Diagnostic {
2174 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2175 // severity: Some(lsp::DiagnosticSeverity::ERROR),
2176 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2177 // ..Default::default()
2178 // };
2179
2180 // project.update(cx, |project, cx| {
2181 // project.lsp_store().update(cx, |lsp_store, cx| {
2182 // // Create some diagnostics
2183 // lsp_store
2184 // .update_diagnostics(
2185 // LanguageServerId(0),
2186 // lsp::PublishDiagnosticsParams {
2187 // uri: path_to_buffer_uri.clone(),
2188 // diagnostics: vec![diagnostic],
2189 // version: None,
2190 // },
2191 // None,
2192 // language::DiagnosticSourceKind::Pushed,
2193 // &[],
2194 // cx,
2195 // )
2196 // .unwrap();
2197 // });
2198 // });
2199
2200 // let buffer = project
2201 // .update(cx, |project, cx| {
2202 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2203 // project.open_buffer(path, cx)
2204 // })
2205 // .await
2206 // .unwrap();
2207
2208 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2209 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2210
2211 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2212 // zeta.request_prediction(&project, &buffer, position, cx)
2213 // });
2214
2215 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2216
2217 // assert_eq!(request.diagnostic_groups.len(), 1);
2218 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2219 // .unwrap();
2220 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2221 // assert_eq!(
2222 // value,
2223 // json!({
2224 // "entries": [{
2225 // "range": {
2226 // "start": 8,
2227 // "end": 10
2228 // },
2229 // "diagnostic": {
2230 // "source": null,
2231 // "code": null,
2232 // "code_description": null,
2233 // "severity": 1,
2234 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2235 // "markdown": null,
2236 // "group_id": 0,
2237 // "is_primary": true,
2238 // "is_disk_based": false,
2239 // "is_unnecessary": false,
2240 // "source_kind": "Pushed",
2241 // "data": null,
2242 // "underline": true
2243 // }
2244 // }],
2245 // "primary_ix": 0
2246 // })
2247 // );
2248 // }
2249
2250 fn model_response(text: &str) -> open_ai::Response {
2251 open_ai::Response {
2252 id: Uuid::new_v4().to_string(),
2253 object: "response".into(),
2254 created: 0,
2255 model: "model".into(),
2256 choices: vec![open_ai::Choice {
2257 index: 0,
2258 message: open_ai::RequestMessage::Assistant {
2259 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2260 tool_calls: vec![],
2261 },
2262 finish_reason: None,
2263 }],
2264 usage: Usage {
2265 prompt_tokens: 0,
2266 completion_tokens: 0,
2267 total_tokens: 0,
2268 },
2269 }
2270 }
2271
2272 fn prompt_from_request(request: &open_ai::Request) -> &str {
2273 assert_eq!(request.messages.len(), 1);
2274 let open_ai::RequestMessage::User {
2275 content: open_ai::MessageContent::Plain(content),
2276 ..
2277 } = &request.messages[0]
2278 else {
2279 panic!(
2280 "Request does not have single user message of type Plain. {:#?}",
2281 request
2282 );
2283 };
2284 content
2285 }
2286
    /// Shared test fixture: builds a `Zeta` entity backed by a fake HTTP
    /// client, and returns it together with a channel receiver that yields
    /// each intercepted prediction request paired with a one-shot sender the
    /// test uses to supply the model's response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            // Requests hitting the prediction endpoint are forwarded to the
            // test through this channel instead of a real backend.
            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    // Clone per request: the outer closure may be invoked many times.
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: always hand back a dummy token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: decode the body, hand the
                            // request to the test, and await its reply on the
                            // one-shot channel before answering.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Any other path means the code under test made an
                            // unexpected call — fail loudly.
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            // Pretend user 1 is signed in so token/credential paths succeed.
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2341}