1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
10use edit_prediction_context::{
11 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
12 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
13};
14use futures::AsyncReadExt as _;
15use futures::channel::{mpsc, oneshot};
16use gpui::http_client::{AsyncBody, Method};
17use gpui::{
18 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
19 http_client, prelude::*,
20};
21use language::BufferSnapshot;
22use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
23use language_model::{LlmApiToken, RefreshLlmTokenListener};
24use project::Project;
25use release_channel::AppVersion;
26use serde::de::DeserializeOwned;
27use std::collections::{HashMap, VecDeque, hash_map};
28use std::path::Path;
29use std::str::FromStr as _;
30use std::sync::Arc;
31use std::time::{Duration, Instant};
32use thiserror::Error;
33use util::rel_path::RelPathBuf;
34use util::some_or_debug_panic;
35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
36
37mod prediction;
38mod provider;
39
40use crate::prediction::EditPrediction;
41pub use provider::ZetaEditPredictionProvider;
42
/// Consecutive edits to the same buffer recorded within this window of the previous edit are
/// merged into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of events kept in a project's edit history.
const MAX_EVENT_COUNT: usize = 16;
47
48pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
49 use_imports: true,
50 use_references: false,
51 excerpt: EditPredictionExcerptOptions {
52 max_bytes: 512,
53 min_bytes: 128,
54 target_before_cursor_over_total_bytes: 0.5,
55 },
56 score: EditPredictionScoreOptions {
57 omit_excerpt_overlaps: true,
58 },
59};
60
61pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
62 context: DEFAULT_CONTEXT_OPTIONS,
63 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
64 max_diagnostic_bytes: 2048,
65 prompt_format: PromptFormat::DEFAULT,
66 file_indexing_parallelism: 1,
67};
68
/// Newtype that stores the app-wide [`Zeta`] entity as a gpui global.
///
/// Set by `Zeta::global` and read by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
73
/// Edit-prediction engine.
///
/// Holds per-project edit history and syntax indexes, request options, and the client and
/// token used to call the prediction service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed from the subscription created in `Zeta::new`.
    llm_token: LlmApiToken,
    /// Subscription to `RefreshLlmTokenListener`, held for the lifetime of this entity.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when a request fails with `ZedUpdateRequiredError` (see `handle_api_response`).
    update_required: bool,
    /// When set via `debug_info`, each prediction request sends a `PredictionDebugInfo` here.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
84
/// Tunable parameters for prediction requests; defaults are in [`DEFAULT_OPTIONS`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format sent with each request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project's index is first created.
    pub file_indexing_parallelism: usize,
}
93
/// Debug record for a single prediction request.
///
/// Sent on the channel returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent in `EditPredictionContext::gather_context`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from the request, or the rendering error message.
    pub local_prompt: Result<String, String>,
    /// Resolves to the server's debug info.
    ///
    /// Resolves to an error message if the request failed, was skipped, or lacked debug info.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}

/// Server-provided debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
104
/// Per-project prediction state.
struct ZetaProject {
    /// Syntax index used for context retrieval.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent edits, oldest first; bounded via `MAX_EVENT_COUNT` in `Zeta::push_event`.
    events: VecDeque<Event>,
    /// Buffers tracked for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
111
/// The prediction currently offered for a project, plus the buffer whose request produced it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Buffer the prediction was requested from.
    ///
    /// This can differ from the buffer the prediction edits.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
117
118impl CurrentEditPrediction {
119 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
120 let Some(new_edits) = self
121 .prediction
122 .interpolate(&self.prediction.buffer.read(cx))
123 else {
124 return false;
125 };
126
127 if self.prediction.buffer != old_prediction.prediction.buffer {
128 return true;
129 }
130
131 let Some(old_edits) = old_prediction
132 .prediction
133 .interpolate(&old_prediction.prediction.buffer.read(cx))
134 else {
135 return true;
136 };
137
138 // This reduces the occurrence of UI thrash from replacing edits
139 //
140 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
141 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
142 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
143 && old_edits.len() == 1
144 && new_edits.len() == 1
145 {
146 let (old_range, old_text) = &old_edits[0];
147 let (new_range, new_text) = &new_edits[0];
148 new_range == old_range && new_text.starts_with(old_text)
149 } else {
150 true
151 }
152 }
153}
154
/// A prediction from the perspective of a buffer.
///
/// - `Local`: the prediction edits this buffer.
/// - `Jump`: the prediction edits another buffer but was requested from this one.
///
/// See `Zeta::current_prediction_for_buffer`.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}
161
/// A buffer tracked by a [`ZetaProject`] for change reporting.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; compared against the next snapshot.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's events and its release (see `register_buffer_impl`).
    _subscriptions: [gpui::Subscription; 2],
}
166
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    ///
    /// Closely spaced consecutive changes may be merged into one event (see `Zeta::push_event`).
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the (most recent merged) change was recorded.
        timestamp: Instant,
    },
}
175
176impl Zeta {
177 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
178 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
179 }
180
181 pub fn global(
182 client: &Arc<Client>,
183 user_store: &Entity<UserStore>,
184 cx: &mut App,
185 ) -> Entity<Self> {
186 cx.try_global::<ZetaGlobal>()
187 .map(|global| global.0.clone())
188 .unwrap_or_else(|| {
189 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
190 cx.set_global(ZetaGlobal(zeta.clone()));
191 zeta
192 })
193 }
194
195 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
196 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
197
198 Self {
199 projects: HashMap::new(),
200 client,
201 user_store,
202 options: DEFAULT_OPTIONS,
203 llm_token: LlmApiToken::default(),
204 _llm_token_subscription: cx.subscribe(
205 &refresh_llm_token_listener,
206 |this, _listener, _event, cx| {
207 let client = this.client.clone();
208 let llm_token = this.llm_token.clone();
209 cx.spawn(async move |_this, _cx| {
210 llm_token.refresh(&client).await?;
211 anyhow::Ok(())
212 })
213 .detach_and_log_err(cx);
214 },
215 ),
216 update_required: false,
217 debug_tx: None,
218 }
219 }
220
221 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
222 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
223 self.debug_tx = Some(debug_watch_tx);
224 debug_watch_rx
225 }
226
    /// Returns the options used for new prediction requests.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the request options used by subsequent requests.
    ///
    /// `file_indexing_parallelism` is only read when a project's syntax index is first created.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
234
235 pub fn clear_history(&mut self) {
236 for zeta_project in self.projects.values_mut() {
237 zeta_project.events.clear();
238 }
239 }
240
241 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
242 self.user_store.read(cx).edit_prediction_usage()
243 }
244
245 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
246 self.get_or_init_zeta_project(project, cx);
247 }
248
249 pub fn register_buffer(
250 &mut self,
251 buffer: &Entity<Buffer>,
252 project: &Entity<Project>,
253 cx: &mut Context<Self>,
254 ) {
255 let zeta_project = self.get_or_init_zeta_project(project, cx);
256 Self::register_buffer_impl(zeta_project, buffer, project, cx);
257 }
258
259 fn get_or_init_zeta_project(
260 &mut self,
261 project: &Entity<Project>,
262 cx: &mut App,
263 ) -> &mut ZetaProject {
264 self.projects
265 .entry(project.entity_id())
266 .or_insert_with(|| ZetaProject {
267 syntax_index: cx.new(|cx| {
268 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
269 }),
270 events: VecDeque::new(),
271 registered_buffers: HashMap::new(),
272 current_prediction: None,
273 })
274 }
275
276 fn register_buffer_impl<'a>(
277 zeta_project: &'a mut ZetaProject,
278 buffer: &Entity<Buffer>,
279 project: &Entity<Project>,
280 cx: &mut Context<Self>,
281 ) -> &'a mut RegisteredBuffer {
282 let buffer_id = buffer.entity_id();
283 match zeta_project.registered_buffers.entry(buffer_id) {
284 hash_map::Entry::Occupied(entry) => entry.into_mut(),
285 hash_map::Entry::Vacant(entry) => {
286 let snapshot = buffer.read(cx).snapshot();
287 let project_entity_id = project.entity_id();
288 entry.insert(RegisteredBuffer {
289 snapshot,
290 _subscriptions: [
291 cx.subscribe(buffer, {
292 let project = project.downgrade();
293 move |this, buffer, event, cx| {
294 if let language::BufferEvent::Edited = event
295 && let Some(project) = project.upgrade()
296 {
297 this.report_changes_for_buffer(&buffer, &project, cx);
298 }
299 }
300 }),
301 cx.observe_release(buffer, move |this, _buffer, _cx| {
302 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
303 else {
304 return;
305 };
306 zeta_project.registered_buffers.remove(&buffer_id);
307 }),
308 ],
309 })
310 }
311 }
312 }
313
314 fn report_changes_for_buffer(
315 &mut self,
316 buffer: &Entity<Buffer>,
317 project: &Entity<Project>,
318 cx: &mut Context<Self>,
319 ) -> BufferSnapshot {
320 let zeta_project = self.get_or_init_zeta_project(project, cx);
321 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
322
323 let new_snapshot = buffer.read(cx).snapshot();
324 if new_snapshot.version != registered_buffer.snapshot.version {
325 let old_snapshot =
326 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
327 Self::push_event(
328 zeta_project,
329 Event::BufferChange {
330 old_snapshot,
331 new_snapshot: new_snapshot.clone(),
332 timestamp: Instant::now(),
333 },
334 );
335 }
336
337 new_snapshot
338 }
339
340 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
341 let events = &mut zeta_project.events;
342
343 if let Some(Event::BufferChange {
344 new_snapshot: last_new_snapshot,
345 timestamp: last_timestamp,
346 ..
347 }) = events.back_mut()
348 {
349 // Coalesce edits for the same buffer when they happen one after the other.
350 let Event::BufferChange {
351 old_snapshot,
352 new_snapshot,
353 timestamp,
354 } = &event;
355
356 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
357 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
358 && old_snapshot.version == last_new_snapshot.version
359 {
360 *last_new_snapshot = new_snapshot.clone();
361 *last_timestamp = *timestamp;
362 return;
363 }
364 }
365
366 if events.len() >= MAX_EVENT_COUNT {
367 // These are halved instead of popping to improve prompt caching.
368 events.drain(..MAX_EVENT_COUNT / 2);
369 }
370
371 events.push_back(event);
372 }
373
374 fn current_prediction_for_buffer(
375 &self,
376 buffer: &Entity<Buffer>,
377 project: &Entity<Project>,
378 cx: &App,
379 ) -> Option<BufferEditPrediction<'_>> {
380 let project_state = self.projects.get(&project.entity_id())?;
381
382 let CurrentEditPrediction {
383 requested_by_buffer_id,
384 prediction,
385 } = project_state.current_prediction.as_ref()?;
386
387 if prediction.targets_buffer(buffer.read(cx), cx) {
388 Some(BufferEditPrediction::Local { prediction })
389 } else if *requested_by_buffer_id == buffer.entity_id() {
390 Some(BufferEditPrediction::Jump { prediction })
391 } else {
392 None
393 }
394 }
395
396 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
397 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
398 return;
399 };
400
401 let Some(prediction) = project_state.current_prediction.take() else {
402 return;
403 };
404 let request_id = prediction.prediction.id.into();
405
406 let client = self.client.clone();
407 let llm_token = self.llm_token.clone();
408 let app_version = AppVersion::global(cx);
409 cx.spawn(async move |this, cx| {
410 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
411 http_client::Url::parse(&predict_edits_url)?
412 } else {
413 client
414 .http_client()
415 .build_zed_llm_url("/predict_edits/accept", &[])?
416 };
417
418 let response = cx
419 .background_spawn(Self::send_api_request::<()>(
420 move |builder| {
421 let req = builder.uri(url.as_ref()).body(
422 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
423 );
424 Ok(req?)
425 },
426 client,
427 llm_token,
428 app_version,
429 ))
430 .await;
431
432 Self::handle_api_response(&this, response, cx)?;
433 anyhow::Ok(())
434 })
435 .detach_and_log_err(cx);
436 }
437
438 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
439 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
440 project_state.current_prediction.take();
441 };
442 }
443
444 pub fn refresh_prediction(
445 &mut self,
446 project: &Entity<Project>,
447 buffer: &Entity<Buffer>,
448 position: language::Anchor,
449 cx: &mut Context<Self>,
450 ) -> Task<Result<()>> {
451 let request_task = self.request_prediction(project, buffer, position, cx);
452 let buffer = buffer.clone();
453 let project = project.clone();
454
455 cx.spawn(async move |this, cx| {
456 if let Some(prediction) = request_task.await? {
457 this.update(cx, |this, cx| {
458 let project_state = this
459 .projects
460 .get_mut(&project.entity_id())
461 .context("Project not found")?;
462
463 let new_prediction = CurrentEditPrediction {
464 requested_by_buffer_id: buffer.entity_id(),
465 prediction: prediction,
466 };
467
468 if project_state
469 .current_prediction
470 .as_ref()
471 .is_none_or(|old_prediction| {
472 new_prediction.should_replace_prediction(&old_prediction, cx)
473 })
474 {
475 project_state.current_prediction = Some(new_prediction);
476 }
477 anyhow::Ok(())
478 })??;
479 }
480 Ok(())
481 })
482 }
483
484 fn request_prediction(
485 &mut self,
486 project: &Entity<Project>,
487 buffer: &Entity<Buffer>,
488 position: language::Anchor,
489 cx: &mut Context<Self>,
490 ) -> Task<Result<Option<EditPrediction>>> {
491 let project_state = self.projects.get(&project.entity_id());
492
493 let index_state = project_state.map(|state| {
494 state
495 .syntax_index
496 .read_with(cx, |index, _cx| index.state().clone())
497 });
498 let options = self.options.clone();
499 let snapshot = buffer.read(cx).snapshot();
500 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
501 return Task::ready(Err(anyhow!("No file path for excerpt")));
502 };
503 let client = self.client.clone();
504 let llm_token = self.llm_token.clone();
505 let app_version = AppVersion::global(cx);
506 let worktree_snapshots = project
507 .read(cx)
508 .worktrees(cx)
509 .map(|worktree| worktree.read(cx).snapshot())
510 .collect::<Vec<_>>();
511 let debug_tx = self.debug_tx.clone();
512
513 let events = project_state
514 .map(|state| {
515 state
516 .events
517 .iter()
518 .filter_map(|event| match event {
519 Event::BufferChange {
520 old_snapshot,
521 new_snapshot,
522 ..
523 } => {
524 let path = new_snapshot.file().map(|f| f.full_path(cx));
525
526 let old_path = old_snapshot.file().and_then(|f| {
527 let old_path = f.full_path(cx);
528 if Some(&old_path) != path.as_ref() {
529 Some(old_path)
530 } else {
531 None
532 }
533 });
534
535 // TODO [zeta2] move to bg?
536 let diff =
537 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
538
539 if path == old_path && diff.is_empty() {
540 None
541 } else {
542 Some(predict_edits_v3::Event::BufferChange {
543 old_path,
544 path,
545 diff,
546 //todo: Actually detect if this edit was predicted or not
547 predicted: false,
548 })
549 }
550 }
551 })
552 .collect::<Vec<_>>()
553 })
554 .unwrap_or_default();
555
556 let diagnostics = snapshot.diagnostic_sets().clone();
557
558 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
559 let mut path = f.worktree.read(cx).absolutize(&f.path);
560 if path.pop() { Some(path) } else { None }
561 });
562
563 let request_task = cx.background_spawn({
564 let snapshot = snapshot.clone();
565 let buffer = buffer.clone();
566 async move {
567 let index_state = if let Some(index_state) = index_state {
568 Some(index_state.lock_owned().await)
569 } else {
570 None
571 };
572
573 let cursor_offset = position.to_offset(&snapshot);
574 let cursor_point = cursor_offset.to_point(&snapshot);
575
576 let before_retrieval = chrono::Utc::now();
577
578 let Some(context) = EditPredictionContext::gather_context(
579 cursor_point,
580 &snapshot,
581 parent_abs_path.as_deref(),
582 &options.context,
583 index_state.as_deref(),
584 ) else {
585 return Ok((None, None));
586 };
587
588 let retrieval_time = chrono::Utc::now() - before_retrieval;
589
590 let (diagnostic_groups, diagnostic_groups_truncated) =
591 Self::gather_nearby_diagnostics(
592 cursor_offset,
593 &diagnostics,
594 &snapshot,
595 options.max_diagnostic_bytes,
596 );
597
598 let debug_context = debug_tx.map(|tx| (tx, context.clone()));
599
600 let request = make_cloud_request(
601 excerpt_path,
602 context,
603 events,
604 // TODO data collection
605 false,
606 diagnostic_groups,
607 diagnostic_groups_truncated,
608 None,
609 debug_context.is_some(),
610 &worktree_snapshots,
611 index_state.as_deref(),
612 Some(options.max_prompt_bytes),
613 options.prompt_format,
614 );
615
616 let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
617 let (response_tx, response_rx) = oneshot::channel();
618
619 let local_prompt = PlannedPrompt::populate(&request)
620 .and_then(|p| p.to_prompt_string().map(|p| p.0))
621 .map_err(|err| err.to_string());
622
623 debug_tx
624 .unbounded_send(PredictionDebugInfo {
625 context,
626 retrieval_time,
627 buffer: buffer.downgrade(),
628 local_prompt,
629 position,
630 response_rx,
631 })
632 .ok();
633 Some(response_tx)
634 } else {
635 None
636 };
637
638 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
639 if let Some(debug_response_tx) = debug_response_tx {
640 debug_response_tx
641 .send(Err("Request skipped".to_string()))
642 .ok();
643 }
644 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
645 }
646
647 let response =
648 Self::send_prediction_request(client, llm_token, app_version, request).await;
649
650 if let Some(debug_response_tx) = debug_response_tx {
651 debug_response_tx
652 .send(response.as_ref().map_err(|err| err.to_string()).and_then(
653 |response| match some_or_debug_panic(response.0.debug_info.clone()) {
654 Some(debug_info) => Ok(debug_info),
655 None => Err("Missing debug info".to_string()),
656 },
657 ))
658 .ok();
659 }
660
661 response.map(|(res, usage)| (Some(res), usage))
662 }
663 });
664
665 let buffer = buffer.clone();
666
667 cx.spawn({
668 let project = project.clone();
669 async move |this, cx| {
670 let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
671 else {
672 return Ok(None);
673 };
674
675 // TODO telemetry: duration, etc
676 Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
677 }
678 })
679 }
680
681 async fn send_prediction_request(
682 client: Arc<Client>,
683 llm_token: LlmApiToken,
684 app_version: SemanticVersion,
685 request: predict_edits_v3::PredictEditsRequest,
686 ) -> Result<(
687 predict_edits_v3::PredictEditsResponse,
688 Option<EditPredictionUsage>,
689 )> {
690 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
691 http_client::Url::parse(&predict_edits_url)?
692 } else {
693 client
694 .http_client()
695 .build_zed_llm_url("/predict_edits/v3", &[])?
696 };
697
698 Self::send_api_request(
699 |builder| {
700 let req = builder
701 .uri(url.as_ref())
702 .body(serde_json::to_string(&request)?.into());
703 Ok(req?)
704 },
705 client,
706 llm_token,
707 app_version,
708 )
709 .await
710 }
711
712 fn handle_api_response<T>(
713 this: &WeakEntity<Self>,
714 response: Result<(T, Option<EditPredictionUsage>)>,
715 cx: &mut gpui::AsyncApp,
716 ) -> Result<T> {
717 match response {
718 Ok((data, usage)) => {
719 if let Some(usage) = usage {
720 this.update(cx, |this, cx| {
721 this.user_store.update(cx, |user_store, cx| {
722 user_store.update_edit_prediction_usage(usage, cx);
723 });
724 })
725 .ok();
726 }
727 Ok(data)
728 }
729 Err(err) => {
730 if err.is::<ZedUpdateRequiredError>() {
731 cx.update(|cx| {
732 this.update(cx, |this, _cx| {
733 this.update_required = true;
734 })
735 .ok();
736
737 let error_message: SharedString = err.to_string().into();
738 show_app_notification(
739 NotificationId::unique::<ZedUpdateRequiredError>(),
740 cx,
741 move |cx| {
742 cx.new(|cx| {
743 ErrorMessagePrompt::new(error_message.clone(), cx)
744 .with_link_button("Update Zed", "https://zed.dev/releases")
745 })
746 },
747 );
748 })
749 .ok();
750 }
751 Err(err)
752 }
753 }
754 }
755
756 async fn send_api_request<Res>(
757 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
758 client: Arc<Client>,
759 llm_token: LlmApiToken,
760 app_version: SemanticVersion,
761 ) -> Result<(Res, Option<EditPredictionUsage>)>
762 where
763 Res: DeserializeOwned,
764 {
765 let http_client = client.http_client();
766 let mut token = llm_token.acquire(&client).await?;
767 let mut did_retry = false;
768
769 loop {
770 let request_builder = http_client::Request::builder().method(Method::POST);
771
772 let request = build(
773 request_builder
774 .header("Content-Type", "application/json")
775 .header("Authorization", format!("Bearer {}", token))
776 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
777 )?;
778
779 let mut response = http_client.send(request).await?;
780
781 if let Some(minimum_required_version) = response
782 .headers()
783 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
784 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
785 {
786 anyhow::ensure!(
787 app_version >= minimum_required_version,
788 ZedUpdateRequiredError {
789 minimum_version: minimum_required_version
790 }
791 );
792 }
793
794 if response.status().is_success() {
795 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
796
797 let mut body = Vec::new();
798 response.body_mut().read_to_end(&mut body).await?;
799 return Ok((serde_json::from_slice(&body)?, usage));
800 } else if !did_retry
801 && response
802 .headers()
803 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
804 .is_some()
805 {
806 did_retry = true;
807 token = llm_token.refresh(&client).await?;
808 } else {
809 let mut body = String::new();
810 response.body_mut().read_to_string(&mut body).await?;
811 anyhow::bail!(
812 "Request failed with status: {:?}\nBody: {}",
813 response.status(),
814 body
815 );
816 }
817 }
818 }
819
820 fn gather_nearby_diagnostics(
821 cursor_offset: usize,
822 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
823 snapshot: &BufferSnapshot,
824 max_diagnostics_bytes: usize,
825 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
826 // TODO: Could make this more efficient
827 let mut diagnostic_groups = Vec::new();
828 for (language_server_id, diagnostics) in diagnostic_sets {
829 let mut groups = Vec::new();
830 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
831 diagnostic_groups.extend(
832 groups
833 .into_iter()
834 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
835 );
836 }
837
838 // sort by proximity to cursor
839 diagnostic_groups.sort_by_key(|group| {
840 let range = &group.entries[group.primary_ix].range;
841 if range.start >= cursor_offset {
842 range.start - cursor_offset
843 } else if cursor_offset >= range.end {
844 cursor_offset - range.end
845 } else {
846 (cursor_offset - range.start).min(range.end - cursor_offset)
847 }
848 });
849
850 let mut results = Vec::new();
851 let mut diagnostic_groups_truncated = false;
852 let mut diagnostics_byte_count = 0;
853 for group in diagnostic_groups {
854 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
855 diagnostics_byte_count += raw_value.get().len();
856 if diagnostics_byte_count > max_diagnostics_bytes {
857 diagnostic_groups_truncated = true;
858 break;
859 }
860 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
861 }
862
863 (results, diagnostic_groups_truncated)
864 }
865
866 // TODO: Dedupe with similar code in request_prediction?
867 pub fn cloud_request_for_zeta_cli(
868 &mut self,
869 project: &Entity<Project>,
870 buffer: &Entity<Buffer>,
871 position: language::Anchor,
872 cx: &mut Context<Self>,
873 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
874 let project_state = self.projects.get(&project.entity_id());
875
876 let index_state = project_state.map(|state| {
877 state
878 .syntax_index
879 .read_with(cx, |index, _cx| index.state().clone())
880 });
881 let options = self.options.clone();
882 let snapshot = buffer.read(cx).snapshot();
883 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
884 return Task::ready(Err(anyhow!("No file path for excerpt")));
885 };
886 let worktree_snapshots = project
887 .read(cx)
888 .worktrees(cx)
889 .map(|worktree| worktree.read(cx).snapshot())
890 .collect::<Vec<_>>();
891
892 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
893 let mut path = f.worktree.read(cx).absolutize(&f.path);
894 if path.pop() { Some(path) } else { None }
895 });
896
897 cx.background_spawn(async move {
898 let index_state = if let Some(index_state) = index_state {
899 Some(index_state.lock_owned().await)
900 } else {
901 None
902 };
903
904 let cursor_point = position.to_point(&snapshot);
905
906 let debug_info = true;
907 EditPredictionContext::gather_context(
908 cursor_point,
909 &snapshot,
910 parent_abs_path.as_deref(),
911 &options.context,
912 index_state.as_deref(),
913 )
914 .context("Failed to select excerpt")
915 .map(|context| {
916 make_cloud_request(
917 excerpt_path.into(),
918 context,
919 // TODO pass everything
920 Vec::new(),
921 false,
922 Vec::new(),
923 false,
924 None,
925 debug_info,
926 &worktree_snapshots,
927 index_state.as_deref(),
928 Some(options.max_prompt_bytes),
929 options.prompt_format,
930 )
931 })
932 })
933 }
934
935 pub fn wait_for_initial_indexing(
936 &mut self,
937 project: &Entity<Project>,
938 cx: &mut App,
939 ) -> Task<Result<()>> {
940 let zeta_project = self.get_or_init_zeta_project(project, cx);
941 zeta_project
942 .syntax_index
943 .read(cx)
944 .wait_for_initial_file_indexing(cx)
945 }
946}
947
/// Returned when the server requires a newer Zed version than the one making the request.
///
/// Produced by `Zeta::send_api_request`; handled by `Zeta::handle_api_response`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Version advertised via `MINIMUM_REQUIRED_VERSION_HEADER_NAME`.
    minimum_version: SemanticVersion,
}
955
956fn make_cloud_request(
957 excerpt_path: Arc<Path>,
958 context: EditPredictionContext,
959 events: Vec<predict_edits_v3::Event>,
960 can_collect_data: bool,
961 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
962 diagnostic_groups_truncated: bool,
963 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
964 debug_info: bool,
965 worktrees: &Vec<worktree::Snapshot>,
966 index_state: Option<&SyntaxIndexState>,
967 prompt_max_bytes: Option<usize>,
968 prompt_format: PromptFormat,
969) -> predict_edits_v3::PredictEditsRequest {
970 let mut signatures = Vec::new();
971 let mut declaration_to_signature_index = HashMap::default();
972 let mut referenced_declarations = Vec::new();
973
974 for snippet in context.declarations {
975 let project_entry_id = snippet.declaration.project_entry_id();
976 let Some(path) = worktrees.iter().find_map(|worktree| {
977 worktree.entry_for_id(project_entry_id).map(|entry| {
978 let mut full_path = RelPathBuf::new();
979 full_path.push(worktree.root_name());
980 full_path.push(&entry.path);
981 full_path
982 })
983 }) else {
984 continue;
985 };
986
987 let parent_index = index_state.and_then(|index_state| {
988 snippet.declaration.parent().and_then(|parent| {
989 add_signature(
990 parent,
991 &mut declaration_to_signature_index,
992 &mut signatures,
993 index_state,
994 )
995 })
996 });
997
998 let (text, text_is_truncated) = snippet.declaration.item_text();
999 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1000 path: path.as_std_path().into(),
1001 text: text.into(),
1002 range: snippet.declaration.item_line_range(),
1003 text_is_truncated,
1004 signature_range: snippet.declaration.signature_range_in_item_text(),
1005 parent_index,
1006 signature_score: snippet.score(DeclarationStyle::Signature),
1007 declaration_score: snippet.score(DeclarationStyle::Declaration),
1008 score_components: snippet.components,
1009 });
1010 }
1011
1012 let excerpt_parent = index_state.and_then(|index_state| {
1013 context
1014 .excerpt
1015 .parent_declarations
1016 .last()
1017 .and_then(|(parent, _)| {
1018 add_signature(
1019 *parent,
1020 &mut declaration_to_signature_index,
1021 &mut signatures,
1022 index_state,
1023 )
1024 })
1025 });
1026
1027 predict_edits_v3::PredictEditsRequest {
1028 excerpt_path,
1029 excerpt: context.excerpt_text.body,
1030 excerpt_line_range: context.excerpt.line_range,
1031 excerpt_range: context.excerpt.range,
1032 cursor_point: predict_edits_v3::Point {
1033 line: predict_edits_v3::Line(context.cursor_point.row),
1034 column: context.cursor_point.column,
1035 },
1036 referenced_declarations,
1037 signatures,
1038 excerpt_parent,
1039 events,
1040 can_collect_data,
1041 diagnostic_groups,
1042 diagnostic_groups_truncated,
1043 git_info,
1044 debug_info,
1045 prompt_max_bytes,
1046 prompt_format,
1047 }
1048}
1049
1050fn add_signature(
1051 declaration_id: DeclarationId,
1052 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1053 signatures: &mut Vec<Signature>,
1054 index: &SyntaxIndexState,
1055) -> Option<usize> {
1056 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1057 return Some(*signature_index);
1058 }
1059 let Some(parent_declaration) = index.declaration(declaration_id) else {
1060 log::error!("bug: missing parent declaration");
1061 return None;
1062 };
1063 let parent_index = parent_declaration.parent().and_then(|parent| {
1064 add_signature(parent, declaration_to_signature_index, signatures, index)
1065 });
1066 let (text, text_is_truncated) = parent_declaration.signature_text();
1067 let signature_index = signatures.len();
1068 signatures.push(Signature {
1069 text: text.into(),
1070 text_is_truncated,
1071 parent_index,
1072 range: parent_declaration.signature_line_range(),
1073 });
1074 declaration_to_signature_index.insert(declaration_id, signature_index);
1075 Some(signature_index)
1076}
1077
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Exercises `current_prediction_for_buffer` across files.
    ///
    /// A prediction that edits the requesting buffer is reported as `Local`.
    /// A prediction that edits a different file is reported as a `Jump` to
    /// that file. Once the target file is opened, the same prediction is
    /// `Local` for the new buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    // 2.txt is not open yet, so there is no snapshot for it; both
                    // files have three lines, so 1.txt's row count spans 2.txt too.
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Now that the target file is open, the same prediction is local to it.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Sends one prediction request and checks both directions.
    ///
    /// The outgoing request carries the expected excerpt path and cursor
    /// point. A whole-file replacement in the response yields a single edit
    /// that inserts only the new text at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an edit to a registered buffer is sent with the next request.
    ///
    /// The edit appears as a single `BufferChange` event holding a unified
    /// diff. The response is then handled the same way as in
    /// `test_simple_request`.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Byte offset 7 is just after "Hello!\n", i.e. the empty second line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that diagnostics pushed for a file are attached to the request.
    ///
    /// The diagnostics arrive as serialized diagnostic groups.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // Line 1 is "Bye" (three characters), so the end column 5 lies past the
        // end of the file. The expected byte range asserted below is 8..10.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        // The diagnostics are published before the buffer is opened.
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        // Only the outgoing request is inspected; no response is ever sent.
        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Initializes the test app and builds the global `Zeta` on a fake HTTP client.
    ///
    /// The fake client does the following:
    /// - `/client/llm_tokens` returns a fixed test token.
    /// - `/predict_edits/v3` deserializes the request body and sends it,
    ///   together with a oneshot sender, on the returned receiver. It then
    ///   replies with whatever `PredictEditsResponse` the test sends back.
    /// - Any other path panics.
    ///
    /// Returns the `Zeta` entity and the receiver of intercepted prediction
    /// requests.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                // Fail loudly here instead of discarding the error
                                // and hitting an unrelated JSON parse panic below.
                                body.read_to_end(&mut buf)
                                    .await
                                    .expect("failed to read request body");
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                // If the test drops the sender without responding,
                                // `?` turns the cancellation into a request error.
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}