use anyhow::{Context as _, Result, anyhow, bail};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};

use std::env;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod merge_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
pub mod udiff;
mod xml_edits;

use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

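// Environment overrides for local development. As an assumed workflow (not an
// official contract): setting ZED_ZETA2_OLLAMA to any non-empty value routes
// requests to a local Ollama server, defaulting the model to `qwen3-coder:30b`,
// while ZED_ZETA2_MODEL and ZED_PREDICT_EDITS_URL override the model id and
// endpoint individually.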
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Option<Entity<SyntaxIndex>>,
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits.
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
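        // Example: if the old prediction inserted "foo" and the new one
        // inserts "foobar" at the same range, the new prediction strictly
        // extends the old one and replaces it; any other single-edit change
        // keeps the current prediction in place.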
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

/// A prediction from the perspective of a buffer.
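/// `Local` means the prediction edits the buffer itself; `Jump` means the
/// prediction was requested from this buffer but targets a different one, so
/// the UI can offer to jump there (see `current_prediction_for_buffer`).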
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}

impl Event {
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
        }
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
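            // e.g. quickly typing three characters in one buffer produces a
            // single BufferChange event spanning from the pre-typing snapshot
            // to the final one, rather than three separate events.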
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // Drop the oldest half at once, instead of popping events one at a time, to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
        };
    }

    pub fn refresh_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let request_task = self.request_prediction(project, buffer, position, cx);
        let buffer = buffer.clone();
        let project = project.clone();

        cx.spawn(async move |this, cx| {
            if let Some(prediction) = request_task.await? {
                this.update(cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project.entity_id())
                        .context("Project not found")?;

                    let new_prediction = CurrentEditPrediction {
                        requested_by_buffer_id: buffer.entity_id(),
                        prediction,
                    };

                    if project_state
                        .current_prediction
                        .as_ref()
                        .is_none_or(|old_prediction| {
                            new_prediction.should_replace_prediction(old_prediction, cx)
                        })
                    {
                        project_state.current_prediction = Some(new_prediction);
                    }
                    anyhow::Ok(())
                })??;
            }
            Ok(())
        })
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &snapshot,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

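                // Debug-only escape hatch: with ZED_ZETA2_SKIP_REQUEST set in
                // a debug build, surface the built prompt to any debug
                // listeners and bail before making a network request.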
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

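    /// On success, applies any usage info from response headers to the user
    /// store. On `ZedUpdateRequiredError`, records that an update is required
    /// and shows an "Update Zed" notification before propagating the error.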
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

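        // Retry at most once: a response carrying the expired-token header
        // triggers an LLM token refresh and a second attempt; any other
        // failure is returned to the caller.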
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }

    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

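        // Build the search tool's JSON schema once per process; the schema and
        // its top-level description are reused for every retrieval request.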
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

    pub fn set_context(
        &mut self,
        project: Entity<Project>,
        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
    ) {
        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
            zeta_project.context = Some(context);
        }
    }

    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
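        // (the key is the distance from the cursor to the nearest endpoint of
        // each group's primary range)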
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic context mode is not supported in the zeta CLI yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        if let Some(syntax_index) = &zeta_project.syntax_index {
            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
        } else {
            Task::ready(Ok(()))
        }
    }
}

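/// Extracts the plain-text body of the assistant message from the response's
/// final choice, returning `None` (and logging an error) for empty or
/// non-text responses.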
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output from Baseten completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}

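/// Builds a `PredictEditsRequest` from syntax-based context: each retrieved
/// declaration is resolved to a worktree-relative path, and parent signatures
/// are interned into a shared signature table so the server can reconstruct
/// the enclosing scope chain.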
fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &Vec<worktree::Snapshot>,
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

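/// Recursively interns the signature of `declaration_id` and its ancestors
/// into `signatures`, deduplicating via `declaration_to_signature_index`, and
/// returns the signature index for `declaration_id`.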
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}

#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

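/// Persistence hook used by evals to record and replay LLM responses, keyed
/// by entry kind plus a hash of the request (see `send_raw_llm_request`).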
1830#[cfg(feature = "eval-support")]
1831pub trait EvalCache: Send + Sync {
1832 fn read(&self, key: EvalCacheKey) -> Option<String>;
1833 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
1834}
1835
1836#[cfg(test)]
1837mod tests {
1838 use std::{path::Path, sync::Arc};
1839
1840 use client::UserStore;
1841 use clock::FakeSystemClock;
1842 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1843 use futures::{
1844 AsyncReadExt, StreamExt,
1845 channel::{mpsc, oneshot},
1846 };
1847 use gpui::{
1848 Entity, TestAppContext,
1849 http_client::{FakeHttpClient, Response},
1850 prelude::*,
1851 };
1852 use indoc::indoc;
1853 use language::OffsetRangeExt as _;
1854 use open_ai::Usage;
1855 use pretty_assertions::{assert_eq, assert_matches};
1856 use project::{FakeFs, Project};
1857 use serde_json::json;
1858 use settings::SettingsStore;
1859 use util::path;
1860 use uuid::Uuid;
1861
1862 use crate::{BufferEditPrediction, Zeta};
1863
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (zeta, mut req_rx) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye\n",
            "2.txt": "Hola!\nComo\nAdios\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    zeta.update(cx, |zeta, cx| {
        zeta.register_project(&project, cx);
    });

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    // Prediction for the current file
    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    let (_request, respond_tx) = req_rx.next().await.unwrap();

    respond_tx
        .send(model_response(indoc! {r"
            --- a/root/1.txt
            +++ b/root/1.txt
            @@ ... @@
            Hello!
            -How
            +How are you?
            Bye
        "}))
        .unwrap();
    prediction_task.await.unwrap();

    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Context refresh: the model responds with a search tool call instead of an edit.
    let refresh_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
    });
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    respond_tx
        .send(open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: None,
                    tool_calls: vec![open_ai::ToolCall {
                        id: "search".into(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                    .to_string(),
                                arguments: serde_json::to_string(&SearchToolInput {
                                    queries: Box::new([SearchToolQuery {
                                        glob: "root/2.txt".to_string(),
                                        syntax_node: vec![],
                                        content: Some(".".into()),
                                    }]),
                                })
                                .unwrap(),
                            },
                        },
                    }],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        })
        .unwrap();
    refresh_task.await.unwrap();

    zeta.update(cx, |zeta, _cx| {
        zeta.discard_current_prediction(&project);
    });

    // Prediction for another file
    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    respond_tx
        .send(model_response(indoc! {r"
            --- a/root/2.txt
            +++ b/root/2.txt
            @@ ... @@
            Hola!
            -Como
            +Como estas?
            Adios
        "}))
        .unwrap();
    prediction_task.await.unwrap();
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer2, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}

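// A single request/response round-trip: the returned diff is parsed and the
// resulting edit is positioned at the cursor.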
#[gpui::test]
async fn test_simple_request(cx: &mut TestAppContext) {
    let (zeta, mut req_rx) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\nHow\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.request_prediction(&project, &buffer, position, cx)
    });

    let (_, respond_tx) = req_rx.next().await.unwrap();

    // TODO Put back when we have a structured request again
    // assert_eq!(
    //     request.excerpt_path.as_ref(),
    //     Path::new(path!("root/foo.md"))
    // );
    // assert_eq!(
    //     request.cursor_point,
    //     Point {
    //         line: Line(1),
    //         column: 3
    //     }
    // );

    respond_tx
        .send(model_response(indoc! {r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are you?
            Bye
        "}))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}

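// Edits made to a registered buffer are reported to the model as unified
// diff events in the prompt.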
#[gpui::test]
async fn test_request_events(cx: &mut TestAppContext) {
    let (zeta, mut req_rx) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "foo.md": "Hello!\n\nBye\n"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    let buffer = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    zeta.update(cx, |zeta, cx| {
        zeta.register_buffer(&buffer, &project, cx);
    });

    buffer.update(cx, |buffer, cx| {
        buffer.edit(vec![(7..7, "How")], None, cx);
    });

    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot.anchor_before(language::Point::new(1, 3));

    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.request_prediction(&project, &buffer, position, cx)
    });

    let (request, respond_tx) = req_rx.next().await.unwrap();

    // The edit above should appear in the prompt as a diff event.
    let prompt = prompt_from_request(&request);
    assert!(
        prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
            Hello!
            -
            +How
            Bye
        "}),
        "{prompt}"
    );

    respond_tx
        .send(model_response(indoc! {r"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ ... @@
            Hello!
            -How
            +How are you?
            Bye
        "}))
        .unwrap();

    let prediction = prediction_task.await.unwrap().unwrap();

    assert_eq!(prediction.edits.len(), 1);
    assert_eq!(
        prediction.edits[0].0.to_point(&snapshot).start,
        language::Point::new(1, 3)
    );
    assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}

// Skipped until we start including diagnostics in the prompt
// #[gpui::test]
// async fn test_request_diagnostics(cx: &mut TestAppContext) {
//     let (zeta, mut req_rx) = init_test(cx);
//     let fs = FakeFs::new(cx.executor());
//     fs.insert_tree(
//         "/root",
//         json!({
//             "foo.md": "Hello!\nBye"
//         }),
//     )
//     .await;
//     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

//     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
//     let diagnostic = lsp::Diagnostic {
//         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
//         severity: Some(lsp::DiagnosticSeverity::ERROR),
//         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
//         ..Default::default()
//     };

//     project.update(cx, |project, cx| {
//         project.lsp_store().update(cx, |lsp_store, cx| {
//             // Create some diagnostics
//             lsp_store
//                 .update_diagnostics(
//                     LanguageServerId(0),
//                     lsp::PublishDiagnosticsParams {
//                         uri: path_to_buffer_uri.clone(),
//                         diagnostics: vec![diagnostic],
//                         version: None,
//                     },
//                     None,
//                     language::DiagnosticSourceKind::Pushed,
//                     &[],
//                     cx,
//                 )
//                 .unwrap();
//         });
//     });

//     let buffer = project
//         .update(cx, |project, cx| {
//             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
//             project.open_buffer(path, cx)
//         })
//         .await
//         .unwrap();

//     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
//     let position = snapshot.anchor_before(language::Point::new(0, 0));

//     let _prediction_task = zeta.update(cx, |zeta, cx| {
//         zeta.request_prediction(&project, &buffer, position, cx)
//     });

//     let (request, _respond_tx) = req_rx.next().await.unwrap();

//     assert_eq!(request.diagnostic_groups.len(), 1);
//     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
//         .unwrap();
//     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
//     assert_eq!(
//         value,
//         json!({
//             "entries": [{
//                 "range": {
//                     "start": 8,
//                     "end": 10
//                 },
//                 "diagnostic": {
//                     "source": null,
//                     "code": null,
//                     "code_description": null,
//                     "severity": 1,
//                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
//                     "markdown": null,
//                     "group_id": 0,
//                     "is_primary": true,
//                     "is_disk_based": false,
//                     "is_unnecessary": false,
//                     "source_kind": "Pushed",
//                     "data": null,
//                     "underline": true
//                 }
//             }],
//             "primary_ix": 0
//         })
//     );
// }

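/// Builds a minimal model response whose single assistant message carries
/// `text` as plain content and no tool calls.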
fn model_response(text: &str) -> open_ai::Response {
    open_ai::Response {
        id: Uuid::new_v4().to_string(),
        object: "response".into(),
        created: 0,
        model: "model".into(),
        choices: vec![open_ai::Choice {
            index: 0,
            message: open_ai::RequestMessage::Assistant {
                content: Some(open_ai::MessageContent::Plain(text.to_string())),
                tool_calls: vec![],
            },
            finish_reason: None,
        }],
        usage: Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        },
    }
}

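/// Extracts the prompt from a request, panicking unless it consists of a
/// single plain-text user message.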
fn prompt_from_request(request: &open_ai::Request) -> &str {
    assert_eq!(request.messages.len(), 1);
    let open_ai::RequestMessage::User {
        content: open_ai::MessageContent::Plain(content),
        ..
    } = &request.messages[0]
    else {
        panic!(
            "Request does not have a single plain user message: {:#?}",
            request
        );
    };
    content
}

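/// Creates a `Zeta` entity backed by a fake HTTP client. The fake serves a
/// static LLM token and forwards each `/predict_edits/raw` request, along
/// with a oneshot responder, over the returned channel so tests can inspect
/// requests and reply with hand-crafted responses.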
fn init_test(
    cx: &mut TestAppContext,
) -> (
    Entity<Zeta>,
    mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
) {
    cx.update(move |cx| {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        zlog::init_test();

        let (req_tx, req_rx) = mpsc::unbounded();

        let http_client = FakeHttpClient::create({
            move |req| {
                let uri = req.uri().path().to_string();
                let mut body = req.into_body();
                let req_tx = req_tx.clone();
                async move {
                    let resp = match uri.as_str() {
                        // Serve a static LLM token.
                        "/client/llm_tokens" => serde_json::to_string(&json!({
                            "token": "test"
                        }))
                        .unwrap(),
                        // Hand the deserialized request to the test and wait
                        // for it to supply the response.
                        "/predict_edits/raw" => {
                            let mut buf = Vec::new();
                            body.read_to_end(&mut buf).await.ok();
                            let req = serde_json::from_slice(&buf).unwrap();

                            let (res_tx, res_rx) = oneshot::channel();
                            req_tx.unbounded_send((req, res_tx)).unwrap();
                            serde_json::to_string(&res_rx.await?).unwrap()
                        }
                        _ => {
                            panic!("Unexpected path: {}", uri)
                        }
                    };

                    Ok(Response::builder().body(resp.into()).unwrap())
                }
            }
        });

        let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
        client.cloud_client().set_credentials(1, "test".into());

        language_model::init(client.clone(), cx);

        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        let zeta = Zeta::global(&client, &user_store, cx);

        (zeta, req_rx)
    })
}
}