1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
10use edit_prediction_context::{
11 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
12 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
13};
14use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
15use futures::channel::{mpsc, oneshot};
16use futures::future::{BoxFuture, Shared};
17use futures::{AsyncReadExt as _, FutureExt};
18use gpui::http_client::{AsyncBody, Method};
19use gpui::{
20 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
21 http_client, prelude::*,
22};
23use language::BufferSnapshot;
24use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language_model::{LlmApiToken, RefreshLlmTokenListener};
26use project::Project;
27use release_channel::AppVersion;
28use serde::de::DeserializeOwned;
29use std::collections::{HashMap, VecDeque, hash_map};
30use std::path::Path;
31use std::str::FromStr as _;
32use std::sync::Arc;
33use std::time::{Duration, Instant};
34use thiserror::Error;
35use util::rel_path::RelPathBuf;
36use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
37
38mod prediction;
39mod provider;
40
41use crate::prediction::EditPrediction;
42pub use provider::ZetaEditPredictionProvider;
43
/// Maximum time between two consecutive changes to the same buffer for them to be
/// merged into a single [`Event::BufferChange`] (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// Once a project holds this many events, the oldest half is dropped before the
/// next event is appended (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;
48
/// Default options handed to `EditPredictionContext::gather_context` when building
/// prediction requests (as `ZetaOptions::context`).
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    // NOTE(review): presumably 0 disables retrieval of extra declarations — confirm
    // against `EditPredictionContextOptions`.
    max_retrieved_declarations: 0,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
};
61
/// Options a newly constructed [`Zeta`] starts with (see `Zeta::new`); they can be
/// replaced later with `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
69
/// Feature flag named `zeta2`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Returns `false`; presumably this opts the flag out of being enabled for staff
    // by default — confirm against the `FeatureFlag` trait docs.
    fn enabled_for_staff() -> bool {
        false
    }
}
79
/// App-wide global holding the shared [`Zeta`] entity; installed by `Zeta::global`
/// and read back by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
84
/// Edit prediction engine: tracks per-project edit history and the current
/// prediction, and sends prediction requests to the cloud endpoints.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate LLM requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event (see `Zeta::new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once an API response reports that a newer Zed version is required.
    update_required: bool,
    /// When set (via `observe_requests`), each prediction request is also sent here.
    observe_tx: Option<mpsc::UnboundedSender<ObservedPredictionRequest>>,
}
95
/// Tunable parameters used when building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first seen.
    pub file_indexing_parallelism: usize,
}
104
/// A prediction request as reported to the receiver returned by
/// `Zeta::observe_requests`.
#[derive(Clone)]
pub struct ObservedPredictionRequest {
    /// The request as built for the server.
    pub full_request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering the excerpt/declaration context.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from `full_request`, or the rendering error.
    pub local_prompt: Result<String, String>,
    /// Resolves to the server response or an error message (also when the request
    /// is skipped or the response sender is dropped).
    pub response:
        Shared<BoxFuture<'static, Result<predict_edits_v3::PredictEditsResponse, String>>>,
}
115
/// Re-export of the v3 protocol's `DebugInfo` type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
117
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Syntax index whose state is used to gather declaration context.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first, bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
124
/// A prediction together with the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the request originated in; the prediction itself
    /// may target a different buffer.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
130
131impl CurrentEditPrediction {
132 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
133 let Some(new_edits) = self
134 .prediction
135 .interpolate(&self.prediction.buffer.read(cx))
136 else {
137 return false;
138 };
139
140 if self.prediction.buffer != old_prediction.prediction.buffer {
141 return true;
142 }
143
144 let Some(old_edits) = old_prediction
145 .prediction
146 .interpolate(&old_prediction.prediction.buffer.read(cx))
147 else {
148 return true;
149 };
150
151 // This reduces the occurrence of UI thrash from replacing edits
152 //
153 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
154 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
155 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
156 && old_edits.len() == 1
157 && new_edits.len() == 1
158 {
159 let (old_range, old_text) = &old_edits[0];
160 let (new_range, new_text) = &new_edits[0];
161 new_range == old_range && new_text.starts_with(old_text)
162 } else {
163 true
164 }
165 }
166}
167
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction does not target this buffer, but was requested from it.
    Jump { prediction: &'a EditPrediction },
}
174
/// Tracking state for a buffer registered with [`Zeta`].
struct RegisteredBuffer {
    /// Last snapshot recorded; compared with the buffer's current snapshot version
    /// to detect changes.
    snapshot: BufferSnapshot,
    /// Keeps the buffer's edit subscription and release observer alive.
    _subscriptions: [gpui::Subscription; 2],
}
179
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes
    /// to the same buffer may be merged into one event (see `Zeta::push_event`).
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the (most recently merged) change was recorded.
        timestamp: Instant,
    },
}
188
189impl Zeta {
190 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
191 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
192 }
193
194 pub fn global(
195 client: &Arc<Client>,
196 user_store: &Entity<UserStore>,
197 cx: &mut App,
198 ) -> Entity<Self> {
199 cx.try_global::<ZetaGlobal>()
200 .map(|global| global.0.clone())
201 .unwrap_or_else(|| {
202 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
203 cx.set_global(ZetaGlobal(zeta.clone()));
204 zeta
205 })
206 }
207
208 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
209 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
210
211 Self {
212 projects: HashMap::new(),
213 client,
214 user_store,
215 options: DEFAULT_OPTIONS,
216 llm_token: LlmApiToken::default(),
217 _llm_token_subscription: cx.subscribe(
218 &refresh_llm_token_listener,
219 |this, _listener, _event, cx| {
220 let client = this.client.clone();
221 let llm_token = this.llm_token.clone();
222 cx.spawn(async move |_this, _cx| {
223 llm_token.refresh(&client).await?;
224 anyhow::Ok(())
225 })
226 .detach_and_log_err(cx);
227 },
228 ),
229 update_required: false,
230 observe_tx: None,
231 }
232 }
233
234 pub fn observe_requests(&mut self) -> mpsc::UnboundedReceiver<ObservedPredictionRequest> {
235 let (tx, rx) = mpsc::unbounded();
236 self.observe_tx = Some(tx);
237 rx
238 }
239
240 pub fn options(&self) -> &ZetaOptions {
241 &self.options
242 }
243
244 pub fn set_options(&mut self, options: ZetaOptions) {
245 self.options = options;
246 }
247
248 pub fn clear_history(&mut self) {
249 for zeta_project in self.projects.values_mut() {
250 zeta_project.events.clear();
251 }
252 }
253
254 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
255 self.user_store.read(cx).edit_prediction_usage()
256 }
257
258 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
259 self.get_or_init_zeta_project(project, cx);
260 }
261
262 pub fn register_buffer(
263 &mut self,
264 buffer: &Entity<Buffer>,
265 project: &Entity<Project>,
266 cx: &mut Context<Self>,
267 ) {
268 let zeta_project = self.get_or_init_zeta_project(project, cx);
269 Self::register_buffer_impl(zeta_project, buffer, project, cx);
270 }
271
272 fn get_or_init_zeta_project(
273 &mut self,
274 project: &Entity<Project>,
275 cx: &mut App,
276 ) -> &mut ZetaProject {
277 self.projects
278 .entry(project.entity_id())
279 .or_insert_with(|| ZetaProject {
280 syntax_index: cx.new(|cx| {
281 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
282 }),
283 events: VecDeque::new(),
284 registered_buffers: HashMap::new(),
285 current_prediction: None,
286 })
287 }
288
289 fn register_buffer_impl<'a>(
290 zeta_project: &'a mut ZetaProject,
291 buffer: &Entity<Buffer>,
292 project: &Entity<Project>,
293 cx: &mut Context<Self>,
294 ) -> &'a mut RegisteredBuffer {
295 let buffer_id = buffer.entity_id();
296 match zeta_project.registered_buffers.entry(buffer_id) {
297 hash_map::Entry::Occupied(entry) => entry.into_mut(),
298 hash_map::Entry::Vacant(entry) => {
299 let snapshot = buffer.read(cx).snapshot();
300 let project_entity_id = project.entity_id();
301 entry.insert(RegisteredBuffer {
302 snapshot,
303 _subscriptions: [
304 cx.subscribe(buffer, {
305 let project = project.downgrade();
306 move |this, buffer, event, cx| {
307 if let language::BufferEvent::Edited = event
308 && let Some(project) = project.upgrade()
309 {
310 this.report_changes_for_buffer(&buffer, &project, cx);
311 }
312 }
313 }),
314 cx.observe_release(buffer, move |this, _buffer, _cx| {
315 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
316 else {
317 return;
318 };
319 zeta_project.registered_buffers.remove(&buffer_id);
320 }),
321 ],
322 })
323 }
324 }
325 }
326
327 fn report_changes_for_buffer(
328 &mut self,
329 buffer: &Entity<Buffer>,
330 project: &Entity<Project>,
331 cx: &mut Context<Self>,
332 ) -> BufferSnapshot {
333 let zeta_project = self.get_or_init_zeta_project(project, cx);
334 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
335
336 let new_snapshot = buffer.read(cx).snapshot();
337 if new_snapshot.version != registered_buffer.snapshot.version {
338 let old_snapshot =
339 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
340 Self::push_event(
341 zeta_project,
342 Event::BufferChange {
343 old_snapshot,
344 new_snapshot: new_snapshot.clone(),
345 timestamp: Instant::now(),
346 },
347 );
348 }
349
350 new_snapshot
351 }
352
353 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
354 let events = &mut zeta_project.events;
355
356 if let Some(Event::BufferChange {
357 new_snapshot: last_new_snapshot,
358 timestamp: last_timestamp,
359 ..
360 }) = events.back_mut()
361 {
362 // Coalesce edits for the same buffer when they happen one after the other.
363 let Event::BufferChange {
364 old_snapshot,
365 new_snapshot,
366 timestamp,
367 } = &event;
368
369 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
370 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
371 && old_snapshot.version == last_new_snapshot.version
372 {
373 *last_new_snapshot = new_snapshot.clone();
374 *last_timestamp = *timestamp;
375 return;
376 }
377 }
378
379 if events.len() >= MAX_EVENT_COUNT {
380 // These are halved instead of popping to improve prompt caching.
381 events.drain(..MAX_EVENT_COUNT / 2);
382 }
383
384 events.push_back(event);
385 }
386
387 fn current_prediction_for_buffer(
388 &self,
389 buffer: &Entity<Buffer>,
390 project: &Entity<Project>,
391 cx: &App,
392 ) -> Option<BufferEditPrediction<'_>> {
393 let project_state = self.projects.get(&project.entity_id())?;
394
395 let CurrentEditPrediction {
396 requested_by_buffer_id,
397 prediction,
398 } = project_state.current_prediction.as_ref()?;
399
400 if prediction.targets_buffer(buffer.read(cx), cx) {
401 Some(BufferEditPrediction::Local { prediction })
402 } else if *requested_by_buffer_id == buffer.entity_id() {
403 Some(BufferEditPrediction::Jump { prediction })
404 } else {
405 None
406 }
407 }
408
409 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
410 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
411 return;
412 };
413
414 let Some(prediction) = project_state.current_prediction.take() else {
415 return;
416 };
417 let request_id = prediction.prediction.id.into();
418
419 let client = self.client.clone();
420 let llm_token = self.llm_token.clone();
421 let app_version = AppVersion::global(cx);
422 cx.spawn(async move |this, cx| {
423 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
424 http_client::Url::parse(&predict_edits_url)?
425 } else {
426 client
427 .http_client()
428 .build_zed_llm_url("/predict_edits/accept", &[])?
429 };
430
431 let response = cx
432 .background_spawn(Self::send_api_request::<()>(
433 move |builder| {
434 let req = builder.uri(url.as_ref()).body(
435 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
436 );
437 Ok(req?)
438 },
439 client,
440 llm_token,
441 app_version,
442 ))
443 .await;
444
445 Self::handle_api_response(&this, response, cx)?;
446 anyhow::Ok(())
447 })
448 .detach_and_log_err(cx);
449 }
450
451 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
452 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
453 project_state.current_prediction.take();
454 };
455 }
456
457 pub fn refresh_prediction(
458 &mut self,
459 project: &Entity<Project>,
460 buffer: &Entity<Buffer>,
461 position: language::Anchor,
462 cx: &mut Context<Self>,
463 ) -> Task<Result<()>> {
464 let request_task = self.request_prediction(project, buffer, position, cx);
465 let buffer = buffer.clone();
466 let project = project.clone();
467
468 cx.spawn(async move |this, cx| {
469 if let Some(prediction) = request_task.await? {
470 this.update(cx, |this, cx| {
471 let project_state = this
472 .projects
473 .get_mut(&project.entity_id())
474 .context("Project not found")?;
475
476 let new_prediction = CurrentEditPrediction {
477 requested_by_buffer_id: buffer.entity_id(),
478 prediction: prediction,
479 };
480
481 if project_state
482 .current_prediction
483 .as_ref()
484 .is_none_or(|old_prediction| {
485 new_prediction.should_replace_prediction(&old_prediction, cx)
486 })
487 {
488 project_state.current_prediction = Some(new_prediction);
489 }
490 anyhow::Ok(())
491 })??;
492 }
493 Ok(())
494 })
495 }
496
497 fn request_prediction(
498 &mut self,
499 project: &Entity<Project>,
500 buffer: &Entity<Buffer>,
501 position: language::Anchor,
502 cx: &mut Context<Self>,
503 ) -> Task<Result<Option<EditPrediction>>> {
504 let project_state = self.projects.get(&project.entity_id());
505
506 let index_state = project_state.map(|state| {
507 state
508 .syntax_index
509 .read_with(cx, |index, _cx| index.state().clone())
510 });
511 let options = self.options.clone();
512 let snapshot = buffer.read(cx).snapshot();
513 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
514 return Task::ready(Err(anyhow!("No file path for excerpt")));
515 };
516 let client = self.client.clone();
517 let llm_token = self.llm_token.clone();
518 let app_version = AppVersion::global(cx);
519 let worktree_snapshots = project
520 .read(cx)
521 .worktrees(cx)
522 .map(|worktree| worktree.read(cx).snapshot())
523 .collect::<Vec<_>>();
524 let observe_tx = self.observe_tx.clone();
525
526 let events = project_state
527 .map(|state| {
528 state
529 .events
530 .iter()
531 .filter_map(|event| match event {
532 Event::BufferChange {
533 old_snapshot,
534 new_snapshot,
535 ..
536 } => {
537 let path = new_snapshot.file().map(|f| f.full_path(cx));
538
539 let old_path = old_snapshot.file().and_then(|f| {
540 let old_path = f.full_path(cx);
541 if Some(&old_path) != path.as_ref() {
542 Some(old_path)
543 } else {
544 None
545 }
546 });
547
548 // TODO [zeta2] move to bg?
549 let diff =
550 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
551
552 if path == old_path && diff.is_empty() {
553 None
554 } else {
555 Some(predict_edits_v3::Event::BufferChange {
556 old_path,
557 path,
558 diff,
559 //todo: Actually detect if this edit was predicted or not
560 predicted: false,
561 })
562 }
563 }
564 })
565 .collect::<Vec<_>>()
566 })
567 .unwrap_or_default();
568
569 let diagnostics = snapshot.diagnostic_sets().clone();
570
571 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
572 let mut path = f.worktree.read(cx).absolutize(&f.path);
573 if path.pop() { Some(path) } else { None }
574 });
575
576 // TODO data collection
577 let can_collect_data = cx.is_staff();
578
579 let request_task = cx.background_spawn({
580 let snapshot = snapshot.clone();
581 let buffer = buffer.clone();
582 async move {
583 let index_state = if let Some(index_state) = index_state {
584 Some(index_state.lock_owned().await)
585 } else {
586 None
587 };
588
589 let cursor_offset = position.to_offset(&snapshot);
590 let cursor_point = cursor_offset.to_point(&snapshot);
591
592 let before_retrieval = chrono::Utc::now();
593
594 let Some(context) = EditPredictionContext::gather_context(
595 cursor_point,
596 &snapshot,
597 parent_abs_path.as_deref(),
598 &options.context,
599 index_state.as_deref(),
600 ) else {
601 return Ok((None, None));
602 };
603
604 let retrieval_time = chrono::Utc::now() - before_retrieval;
605
606 let (diagnostic_groups, diagnostic_groups_truncated) =
607 Self::gather_nearby_diagnostics(
608 cursor_offset,
609 &diagnostics,
610 &snapshot,
611 options.max_diagnostic_bytes,
612 );
613
614 let request = make_cloud_request(
615 excerpt_path,
616 context,
617 events,
618 can_collect_data,
619 diagnostic_groups,
620 diagnostic_groups_truncated,
621 None,
622 observe_tx.is_some(),
623 &worktree_snapshots,
624 index_state.as_deref(),
625 Some(options.max_prompt_bytes),
626 options.prompt_format,
627 );
628
629 let observe_response_tx = if let Some(observe_tx) = &observe_tx {
630 let (response_tx, response_rx) = oneshot::channel();
631
632 let local_prompt = PlannedPrompt::populate(&request)
633 .and_then(|p| p.to_prompt_string().map(|p| p.0))
634 .map_err(|err| err.to_string());
635
636 observe_tx
637 .unbounded_send(ObservedPredictionRequest {
638 full_request: request.clone(),
639 retrieval_time,
640 buffer: buffer.downgrade(),
641 local_prompt,
642 position,
643 response: async move {
644 let resp = response_rx
645 .await
646 .map_err(|_: oneshot::Canceled| "Canceled".to_string());
647 Ok(resp??)
648 }
649 .boxed()
650 .shared(),
651 })
652 .ok();
653 Some(response_tx)
654 } else {
655 None
656 };
657
658 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
659 if let Some(observe_response_tx) = observe_response_tx {
660 observe_response_tx
661 .send(Err("Request skipped".to_string()))
662 .ok();
663 }
664 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
665 }
666
667 let response =
668 Self::send_prediction_request(client, llm_token, app_version, request).await;
669
670 if let Some(debug_response_tx) = observe_response_tx {
671 debug_response_tx
672 .send(
673 response
674 .as_ref()
675 .map_err(|err| err.to_string())
676 .map(|response| response.0.clone()),
677 )
678 .ok();
679 }
680
681 response.map(|(res, usage)| (Some(res), usage))
682 }
683 });
684
685 let buffer = buffer.clone();
686
687 cx.spawn({
688 let project = project.clone();
689 async move |this, cx| {
690 let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
691 else {
692 return Ok(None);
693 };
694
695 // TODO telemetry: duration, etc
696 Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
697 }
698 })
699 }
700
701 async fn send_prediction_request(
702 client: Arc<Client>,
703 llm_token: LlmApiToken,
704 app_version: SemanticVersion,
705 request: predict_edits_v3::PredictEditsRequest,
706 ) -> Result<(
707 predict_edits_v3::PredictEditsResponse,
708 Option<EditPredictionUsage>,
709 )> {
710 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
711 http_client::Url::parse(&predict_edits_url)?
712 } else {
713 client
714 .http_client()
715 .build_zed_llm_url("/predict_edits/v3", &[])?
716 };
717
718 Self::send_api_request(
719 |builder| {
720 let req = builder
721 .uri(url.as_ref())
722 .body(serde_json::to_string(&request)?.into());
723 Ok(req?)
724 },
725 client,
726 llm_token,
727 app_version,
728 )
729 .await
730 }
731
732 fn handle_api_response<T>(
733 this: &WeakEntity<Self>,
734 response: Result<(T, Option<EditPredictionUsage>)>,
735 cx: &mut gpui::AsyncApp,
736 ) -> Result<T> {
737 match response {
738 Ok((data, usage)) => {
739 if let Some(usage) = usage {
740 this.update(cx, |this, cx| {
741 this.user_store.update(cx, |user_store, cx| {
742 user_store.update_edit_prediction_usage(usage, cx);
743 });
744 })
745 .ok();
746 }
747 Ok(data)
748 }
749 Err(err) => {
750 if err.is::<ZedUpdateRequiredError>() {
751 cx.update(|cx| {
752 this.update(cx, |this, _cx| {
753 this.update_required = true;
754 })
755 .ok();
756
757 let error_message: SharedString = err.to_string().into();
758 show_app_notification(
759 NotificationId::unique::<ZedUpdateRequiredError>(),
760 cx,
761 move |cx| {
762 cx.new(|cx| {
763 ErrorMessagePrompt::new(error_message.clone(), cx)
764 .with_link_button("Update Zed", "https://zed.dev/releases")
765 })
766 },
767 );
768 })
769 .ok();
770 }
771 Err(err)
772 }
773 }
774 }
775
776 async fn send_api_request<Res>(
777 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
778 client: Arc<Client>,
779 llm_token: LlmApiToken,
780 app_version: SemanticVersion,
781 ) -> Result<(Res, Option<EditPredictionUsage>)>
782 where
783 Res: DeserializeOwned,
784 {
785 let http_client = client.http_client();
786 let mut token = llm_token.acquire(&client).await?;
787 let mut did_retry = false;
788
789 loop {
790 let request_builder = http_client::Request::builder().method(Method::POST);
791
792 let request = build(
793 request_builder
794 .header("Content-Type", "application/json")
795 .header("Authorization", format!("Bearer {}", token))
796 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
797 )?;
798
799 let mut response = http_client.send(request).await?;
800
801 if let Some(minimum_required_version) = response
802 .headers()
803 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
804 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
805 {
806 anyhow::ensure!(
807 app_version >= minimum_required_version,
808 ZedUpdateRequiredError {
809 minimum_version: minimum_required_version
810 }
811 );
812 }
813
814 if response.status().is_success() {
815 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
816
817 let mut body = Vec::new();
818 response.body_mut().read_to_end(&mut body).await?;
819 return Ok((serde_json::from_slice(&body)?, usage));
820 } else if !did_retry
821 && response
822 .headers()
823 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
824 .is_some()
825 {
826 did_retry = true;
827 token = llm_token.refresh(&client).await?;
828 } else {
829 let mut body = String::new();
830 response.body_mut().read_to_string(&mut body).await?;
831 anyhow::bail!(
832 "Request failed with status: {:?}\nBody: {}",
833 response.status(),
834 body
835 );
836 }
837 }
838 }
839
840 fn gather_nearby_diagnostics(
841 cursor_offset: usize,
842 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
843 snapshot: &BufferSnapshot,
844 max_diagnostics_bytes: usize,
845 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
846 // TODO: Could make this more efficient
847 let mut diagnostic_groups = Vec::new();
848 for (language_server_id, diagnostics) in diagnostic_sets {
849 let mut groups = Vec::new();
850 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
851 diagnostic_groups.extend(
852 groups
853 .into_iter()
854 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
855 );
856 }
857
858 // sort by proximity to cursor
859 diagnostic_groups.sort_by_key(|group| {
860 let range = &group.entries[group.primary_ix].range;
861 if range.start >= cursor_offset {
862 range.start - cursor_offset
863 } else if cursor_offset >= range.end {
864 cursor_offset - range.end
865 } else {
866 (cursor_offset - range.start).min(range.end - cursor_offset)
867 }
868 });
869
870 let mut results = Vec::new();
871 let mut diagnostic_groups_truncated = false;
872 let mut diagnostics_byte_count = 0;
873 for group in diagnostic_groups {
874 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
875 diagnostics_byte_count += raw_value.get().len();
876 if diagnostics_byte_count > max_diagnostics_bytes {
877 diagnostic_groups_truncated = true;
878 break;
879 }
880 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
881 }
882
883 (results, diagnostic_groups_truncated)
884 }
885
886 // TODO: Dedupe with similar code in request_prediction?
887 pub fn cloud_request_for_zeta_cli(
888 &mut self,
889 project: &Entity<Project>,
890 buffer: &Entity<Buffer>,
891 position: language::Anchor,
892 cx: &mut Context<Self>,
893 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
894 let project_state = self.projects.get(&project.entity_id());
895
896 let index_state = project_state.map(|state| {
897 state
898 .syntax_index
899 .read_with(cx, |index, _cx| index.state().clone())
900 });
901 let options = self.options.clone();
902 let snapshot = buffer.read(cx).snapshot();
903 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
904 return Task::ready(Err(anyhow!("No file path for excerpt")));
905 };
906 let worktree_snapshots = project
907 .read(cx)
908 .worktrees(cx)
909 .map(|worktree| worktree.read(cx).snapshot())
910 .collect::<Vec<_>>();
911
912 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
913 let mut path = f.worktree.read(cx).absolutize(&f.path);
914 if path.pop() { Some(path) } else { None }
915 });
916
917 cx.background_spawn(async move {
918 let index_state = if let Some(index_state) = index_state {
919 Some(index_state.lock_owned().await)
920 } else {
921 None
922 };
923
924 let cursor_point = position.to_point(&snapshot);
925
926 let debug_info = true;
927 EditPredictionContext::gather_context(
928 cursor_point,
929 &snapshot,
930 parent_abs_path.as_deref(),
931 &options.context,
932 index_state.as_deref(),
933 )
934 .context("Failed to select excerpt")
935 .map(|context| {
936 make_cloud_request(
937 excerpt_path.into(),
938 context,
939 // TODO pass everything
940 Vec::new(),
941 false,
942 Vec::new(),
943 false,
944 None,
945 debug_info,
946 &worktree_snapshots,
947 index_state.as_deref(),
948 Some(options.max_prompt_bytes),
949 options.prompt_format,
950 )
951 })
952 })
953 }
954
955 pub fn wait_for_initial_indexing(
956 &mut self,
957 project: &Entity<Project>,
958 cx: &mut App,
959 ) -> Task<Result<()>> {
960 let zeta_project = self.get_or_init_zeta_project(project, cx);
961 zeta_project
962 .syntax_index
963 .read(cx)
964 .wait_for_initial_file_indexing(cx)
965 }
966}
967
/// Returned when a response advertises a minimum required Zed version newer than
/// the running one (see `Zeta::send_api_request`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: SemanticVersion,
}
975
976fn make_cloud_request(
977 excerpt_path: Arc<Path>,
978 context: EditPredictionContext,
979 events: Vec<predict_edits_v3::Event>,
980 can_collect_data: bool,
981 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
982 diagnostic_groups_truncated: bool,
983 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
984 debug_info: bool,
985 worktrees: &Vec<worktree::Snapshot>,
986 index_state: Option<&SyntaxIndexState>,
987 prompt_max_bytes: Option<usize>,
988 prompt_format: PromptFormat,
989) -> predict_edits_v3::PredictEditsRequest {
990 let mut signatures = Vec::new();
991 let mut declaration_to_signature_index = HashMap::default();
992 let mut referenced_declarations = Vec::new();
993
994 for snippet in context.declarations {
995 let project_entry_id = snippet.declaration.project_entry_id();
996 let Some(path) = worktrees.iter().find_map(|worktree| {
997 worktree.entry_for_id(project_entry_id).map(|entry| {
998 let mut full_path = RelPathBuf::new();
999 full_path.push(worktree.root_name());
1000 full_path.push(&entry.path);
1001 full_path
1002 })
1003 }) else {
1004 continue;
1005 };
1006
1007 let parent_index = index_state.and_then(|index_state| {
1008 snippet.declaration.parent().and_then(|parent| {
1009 add_signature(
1010 parent,
1011 &mut declaration_to_signature_index,
1012 &mut signatures,
1013 index_state,
1014 )
1015 })
1016 });
1017
1018 let (text, text_is_truncated) = snippet.declaration.item_text();
1019 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1020 path: path.as_std_path().into(),
1021 text: text.into(),
1022 range: snippet.declaration.item_line_range(),
1023 text_is_truncated,
1024 signature_range: snippet.declaration.signature_range_in_item_text(),
1025 parent_index,
1026 signature_score: snippet.score(DeclarationStyle::Signature),
1027 declaration_score: snippet.score(DeclarationStyle::Declaration),
1028 score_components: snippet.components,
1029 });
1030 }
1031
1032 let excerpt_parent = index_state.and_then(|index_state| {
1033 context
1034 .excerpt
1035 .parent_declarations
1036 .last()
1037 .and_then(|(parent, _)| {
1038 add_signature(
1039 *parent,
1040 &mut declaration_to_signature_index,
1041 &mut signatures,
1042 index_state,
1043 )
1044 })
1045 });
1046
1047 predict_edits_v3::PredictEditsRequest {
1048 excerpt_path,
1049 excerpt: context.excerpt_text.body,
1050 excerpt_line_range: context.excerpt.line_range,
1051 excerpt_range: context.excerpt.range,
1052 cursor_point: predict_edits_v3::Point {
1053 line: predict_edits_v3::Line(context.cursor_point.row),
1054 column: context.cursor_point.column,
1055 },
1056 referenced_declarations,
1057 signatures,
1058 excerpt_parent,
1059 events,
1060 can_collect_data,
1061 diagnostic_groups,
1062 diagnostic_groups_truncated,
1063 git_info,
1064 debug_info,
1065 prompt_max_bytes,
1066 prompt_format,
1067 }
1068}
1069
1070fn add_signature(
1071 declaration_id: DeclarationId,
1072 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1073 signatures: &mut Vec<Signature>,
1074 index: &SyntaxIndexState,
1075) -> Option<usize> {
1076 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1077 return Some(*signature_index);
1078 }
1079 let Some(parent_declaration) = index.declaration(declaration_id) else {
1080 log::error!("bug: missing parent declaration");
1081 return None;
1082 };
1083 let parent_index = parent_declaration.parent().and_then(|parent| {
1084 add_signature(parent, declaration_to_signature_index, signatures, index)
1085 });
1086 let (text, text_is_truncated) = parent_declaration.signature_text();
1087 let signature_index = signatures.len();
1088 signatures.push(Signature {
1089 text: text.into(),
1090 text_is_truncated,
1091 parent_index,
1092 range: parent_declaration.signature_line_range(),
1093 });
1094 declaration_to_signature_index.insert(declaration_id, signature_index);
1095 Some(signature_index)
1096}
1097
#[cfg(test)]
mod tests {
    // Tests for `Zeta` prediction requests. Every test builds its `Zeta` with
    // `init_test`, which installs a fake HTTP client. That client forwards each
    // `/predict_edits/v3` request to the test over a channel and then waits for
    // the test to send back the response it should return.

    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Checks which kind of prediction `current_prediction_for_buffer` reports
    /// after `refresh_prediction` completes:
    /// - `Local` for buffer1 when the predicted edit targets 1.txt;
    /// - `Jump` (to 2.txt) for buffer1 when the predicted edit targets 2.txt;
    /// - `Local` once 2.txt is opened and its buffer is queried.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Wait for the request to reach the fake server. Answer it with an edit
        // covering every line of 1.txt (rows 0 through max_point().row).
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    // NOTE(review): this range comes from 1.txt's snapshot although the
                    // edit targets 2.txt. Both fixture files have three lines, so it
                    // still spans all of 2.txt.
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        // Seen from buffer1, an edit in 2.txt is expected to be reported as a jump
        // whose path is 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // After opening 2.txt, the same prediction is expected to be local to buffer2.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Checks the excerpt path and cursor point sent by `request_prediction`. Also
    /// checks that a whole-file replacement response yields one edit inserting
    /// " are you?" at (1, 3).
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        // The outgoing request should carry the buffer's project-relative path and
        // the cursor position (row 1, column 3).
        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The whole-file replacement is expected to be reduced to a single
        // insertion of the text that actually differs.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an edit made to a registered buffer is sent in the request as
    /// a single `BufferChange` event with the expected unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer before editing it. The test expects the edit below
        // to show up in the request's `events`.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // "Hello!\n" is 7 bytes, so offset 7 is the start of the empty second line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // Context lines in the expected diff keep their leading space; `indoc!`
        // strips only the common indentation.
        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that a diagnostic pushed through the LSP store is included, as one
    /// serialized diagnostic group, in the outgoing prediction request.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // Range (1, 1)..(1, 5) lies on line 1 ("Bye"). Its end column is past the
        // end of that 3-character line.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // Only the outgoing request is inspected. The `_`-prefixed bindings keep
        // the task and the responder alive until the end of the test, and no
        // response is ever sent.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // "Hello!\n" is 7 bytes, so line 1 starts at offset 7. The expected
        // range 8..10 is therefore (1, 1) through the end of "Bye".
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Initializes settings, language and project settings, and a `Zeta` global
    /// backed by a fake HTTP client.
    ///
    /// Returns the `Zeta` entity and a receiver that yields one
    /// `(request, responder)` pair per `/predict_edits/v3` call. The fake server
    /// answers that call with whatever the test sends through `responder`. Any
    /// path other than `/client/llm_tokens` or `/predict_edits/v3` panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    // Clone per call so the `async move` block below owns its own sender.
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token endpoint: always returns the fixed token "test".
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                // Decode the request and pass it to the test with a
                                // oneshot responder, then wait for the test's reply.
                                // If the responder is dropped without sending,
                                // `res_rx.await` fails and `?` returns that error
                                // from this handler.
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            // Placeholder credentials for the cloud client. NOTE(review): presumably
            // this lets requests proceed without a real sign-in; confirm.
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}