use anyhow::{Context as _, Result, anyhow, bail};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
use std::env;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod merge_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
pub mod udiff;

use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

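/// Coordinates edit prediction state across projects: per-project event
/// histories, registered buffers, retrieved context, and the current
/// prediction, plus the client and token state needed to reach the LLM
/// endpoint.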
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "llm-response-cache")]
    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
}

#[cfg(feature = "llm-response-cache")]
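/// Cache for raw LLM responses, keyed by request URL and body, so identical
/// requests can reuse an earlier response (e.g. in evals or tests).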
pub trait LlmResponseCache: Send + Sync {
    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
    fn read_response(&self, key: u64) -> Option<String>;
    fn write_response(&self, key: u64, value: &str);
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

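/// How context is gathered for a prediction request: either by asking the
/// model to plan retrieval searches (`Agentic`) or by collecting declarations
/// from the syntax index (`Syntax`).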
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

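/// The most recent prediction for a project, along with the buffer it was
/// requested from, so that predictions targeting other buffers can be
/// surfaced as jumps.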
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits.
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}

impl Event {
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "llm-response-cache")]
            llm_response_cache: None,
        }
    }

    #[cfg(feature = "llm-response-cache")]
    pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
        self.llm_response_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

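    /// Appends an event to the project's history, coalescing consecutive edits
    /// to the same buffer that occur within `buffer_change_grouping_interval`.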
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // Drop the oldest half of the events instead of popping them one at
            // a time, to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
        }
    }

    pub fn refresh_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let request_task = self.request_prediction(project, buffer, position, cx);
        let buffer = buffer.clone();
        let project = project.clone();

        cx.spawn(async move |this, cx| {
            if let Some(prediction) = request_task.await? {
                this.update(cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project.entity_id())
                        .context("Project not found")?;

                    let new_prediction = CurrentEditPrediction {
                        requested_by_buffer_id: buffer.entity_id(),
                        prediction,
                    };

                    if project_state
                        .current_prediction
                        .as_ref()
                        .is_none_or(|old_prediction| {
                            new_prediction.should_replace_prediction(old_prediction, cx)
                        })
                    {
                        project_state.current_prediction = Some(new_prediction);
                    }
                    anyhow::Ok(())
                })??;
            }
            Ok(())
        })
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        #[cfg(feature = "llm-response-cache")]
        let llm_response_cache = self.llm_response_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "llm-response-cache")]
                    llm_response_cache,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        crate::udiff::parse_diff(&output_text, |path| {
                            included_files
                                .iter()
                                .find_map(|(_, buffer, probe_path, ranges)| {
                                    if probe_path.as_ref() == path {
                                        Some((buffer, ranges.as_slice()))
                                    } else {
                                        None
                                    }
                                })
                        })
                        .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context(
                        "Failed to find buffer in included_files, even though we just found the snapshot",
                    )?;

                anyhow::Ok((
                    Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
            Arc<dyn LlmResponseCache>,
        >,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "llm-response-cache")]
        let cache_key = if let Some(cache) = llm_response_cache {
            let request_json = serde_json::to_string(&request)?;
            let key = cache.get_key(&url, &request_json);

            if let Some(response_str) = cache.read_response(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "llm-response-cache")]
        if let Some((cache, key)) = cache_key {
            cache.write_response(key, &serde_json::to_string(&response)?);
        }

        Ok((response, usage))
    }

    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

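    /// Sends an authenticated POST request to the LLM service, refreshing the
    /// LLM token and retrying once when the server reports it as expired, and
    /// raising `ZedUpdateRequiredError` when the server requires a newer Zed
    /// version.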
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    }
                })
                .ok()
            }
        }));
    }

    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "llm-response-cache")]
        let llm_response_cache = self.llm_response_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "llm-response-cache")]
                llm_response_cache,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

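    /// Collects diagnostic groups sorted by distance from the cursor, stopping
    /// once their serialized size exceeds `max_diagnostics_bytes`. Returns the
    /// groups along with whether any were truncated.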
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // Sort by proximity to the cursor.
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic context mode is not supported in the zeta CLI yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO: pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        zeta_project
            .syntax_index
            .read(cx)
            .wait_for_initial_file_indexing(cx)
    }
}

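/// Extracts the assistant's text output from a completion response, returning
/// `None` (and logging an error) when the response contains no usable text.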
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output from Baseten completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}

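/// Builds a `PredictEditsRequest` from syntax-based context, converting
/// retrieved declarations into referenced declarations and deduplicated
/// parent signatures.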
fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &Vec<worktree::Snapshot>,
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

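/// Recursively adds the signature for a declaration (and its parents) to
/// `signatures`, memoizing indices in `declaration_to_signature_index`, and
/// returns the index of the added signature.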
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}

1749#[cfg(test)]
1750mod tests {
1751 use std::{path::Path, sync::Arc};
1752
1753 use client::UserStore;
1754 use clock::FakeSystemClock;
1755 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1756 use futures::{
1757 AsyncReadExt, StreamExt,
1758 channel::{mpsc, oneshot},
1759 };
1760 use gpui::{
1761 Entity, TestAppContext,
1762 http_client::{FakeHttpClient, Response},
1763 prelude::*,
1764 };
1765 use indoc::indoc;
1766 use language::OffsetRangeExt as _;
1767 use open_ai::Usage;
1768 use pretty_assertions::{assert_eq, assert_matches};
1769 use project::{FakeFs, Project};
1770 use serde_json::json;
1771 use settings::SettingsStore;
1772 use util::path;
1773 use uuid::Uuid;
1774
1775 use crate::{BufferEditPrediction, Zeta};
1776
1777 #[gpui::test]
1778 async fn test_current_state(cx: &mut TestAppContext) {
1779 let (zeta, mut req_rx) = init_test(cx);
1780 let fs = FakeFs::new(cx.executor());
1781 fs.insert_tree(
1782 "/root",
1783 json!({
1784 "1.txt": "Hello!\nHow\nBye\n",
1785 "2.txt": "Hola!\nComo\nAdios\n"
1786 }),
1787 )
1788 .await;
1789 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1790
1791 zeta.update(cx, |zeta, cx| {
1792 zeta.register_project(&project, cx);
1793 });
1794
1795 let buffer1 = project
1796 .update(cx, |project, cx| {
1797 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1798 project.open_buffer(path, cx)
1799 })
1800 .await
1801 .unwrap();
1802 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1803 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1804
1805 // Prediction for current file
1806
1807 let prediction_task = zeta.update(cx, |zeta, cx| {
1808 zeta.refresh_prediction(&project, &buffer1, position, cx)
1809 });
1810 let (_request, respond_tx) = req_rx.next().await.unwrap();
1811
1812 respond_tx
1813 .send(model_response(indoc! {r"
1814 --- a/root/1.txt
1815 +++ b/root/1.txt
1816 @@ ... @@
1817 Hello!
1818 -How
1819 +How are you?
1820 Bye
1821 "}))
1822 .unwrap();
1823 prediction_task.await.unwrap();
1824
1825 zeta.read_with(cx, |zeta, cx| {
1826 let prediction = zeta
1827 .current_prediction_for_buffer(&buffer1, &project, cx)
1828 .unwrap();
1829 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1830 });
1831
1832 // Context refresh
1833 let refresh_task = zeta.update(cx, |zeta, cx| {
1834 zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
1835 });
1836 let (_request, respond_tx) = req_rx.next().await.unwrap();
1837 respond_tx
1838 .send(open_ai::Response {
1839 id: Uuid::new_v4().to_string(),
1840 object: "response".into(),
1841 created: 0,
1842 model: "model".into(),
1843 choices: vec![open_ai::Choice {
1844 index: 0,
1845 message: open_ai::RequestMessage::Assistant {
1846 content: None,
1847 tool_calls: vec![open_ai::ToolCall {
1848 id: "search".into(),
1849 content: open_ai::ToolCallContent::Function {
1850 function: open_ai::FunctionContent {
1851 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
1852 .to_string(),
1853 arguments: serde_json::to_string(&SearchToolInput {
1854 queries: Box::new([SearchToolQuery {
1855 glob: "root/2.txt".to_string(),
1856 syntax_node: vec![],
1857 content: Some(".".into()),
1858 }]),
1859 })
1860 .unwrap(),
1861 },
1862 },
1863 }],
1864 },
1865 finish_reason: None,
1866 }],
1867 usage: Usage {
1868 prompt_tokens: 0,
1869 completion_tokens: 0,
1870 total_tokens: 0,
1871 },
1872 })
1873 .unwrap();
1874 refresh_task.await.unwrap();
1875
1876 zeta.update(cx, |zeta, _cx| {
1877 zeta.discard_current_prediction(&project);
1878 });
1879
1880 // Prediction for another file
1881 let prediction_task = zeta.update(cx, |zeta, cx| {
1882 zeta.refresh_prediction(&project, &buffer1, position, cx)
1883 });
1884 let (_request, respond_tx) = req_rx.next().await.unwrap();
1885 respond_tx
1886 .send(model_response(indoc! {r#"
1887 --- a/root/2.txt
1888 +++ b/root/2.txt
1889 Hola!
1890 -Como
1891 +Como estas?
1892 Adios
1893 "#}))
1894 .unwrap();
1895 prediction_task.await.unwrap();
1896 zeta.read_with(cx, |zeta, cx| {
1897 let prediction = zeta
1898 .current_prediction_for_buffer(&buffer1, &project, cx)
1899 .unwrap();
1900 assert_matches!(
1901 prediction,
1902 BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
1903 );
1904 });
1905
1906 let buffer2 = project
1907 .update(cx, |project, cx| {
1908 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1909 project.open_buffer(path, cx)
1910 })
1911 .await
1912 .unwrap();
1913
1914 zeta.read_with(cx, |zeta, cx| {
1915 let prediction = zeta
1916 .current_prediction_for_buffer(&buffer2, &project, cx)
1917 .unwrap();
1918 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1919 });
1920 }
1921
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

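    /// Edits made to a registered buffer are reported as udiff events in the
    /// prompt of the next prediction request.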
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

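        // The edit above should be reported to the model as a udiff event.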
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

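    /// Builds a minimal OpenAI-style response whose single assistant message
    /// carries `text` as plain content and no tool calls.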
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

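    /// Like `model_response`, but the assistant message carries a single tool
    /// call (invoking `tool_name` with the given JSON `arguments`) instead of
    /// text content. Extracted from the inline response construction in the
    /// context-refresh test above.
    fn tool_call_response(tool_name: &str, arguments: String) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: None,
                    tool_calls: vec![open_ai::ToolCall {
                        id: "search".into(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_name.to_string(),
                                arguments,
                            },
                        },
                    }],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    /// Extracts the prompt from a request, asserting that it consists of
    /// exactly one plain-text user message.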
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have a single plain user message. {:#?}",
                request
            );
        };
        content
    }

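    /// Creates a `Zeta` entity backed by a fake HTTP client: `/client/llm_tokens`
    /// returns a static token, and each body posted to `/predict_edits/raw` is
    /// decoded and forwarded over the returned channel together with a oneshot
    /// sender for scripting the model's response.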
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
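                            // Stub the LLM token endpoint with a static token.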
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
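                            // Decode the prediction request, hand it to the
                            // test, and wait for its scripted response.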
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
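            // Fake credentials so the client behaves as signed in; its token
            // refresh hits the stubbed endpoint above.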
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}