1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
10use edit_prediction_context::{
11 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
12 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
13};
14use feature_flags::FeatureFlag;
15use futures::AsyncReadExt as _;
16use futures::channel::{mpsc, oneshot};
17use gpui::http_client::{AsyncBody, Method};
18use gpui::{
19 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
20 http_client, prelude::*,
21};
22use language::BufferSnapshot;
23use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
24use language_model::{LlmApiToken, RefreshLlmTokenListener};
25use project::Project;
26use release_channel::AppVersion;
27use serde::de::DeserializeOwned;
28use std::collections::{HashMap, VecDeque, hash_map};
29use std::path::Path;
30use std::str::FromStr as _;
31use std::sync::Arc;
32use std::time::{Duration, Instant};
33use thiserror::Error;
34use util::rel_path::RelPathBuf;
35use util::some_or_debug_panic;
36use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
37
38mod prediction;
39mod provider;
40
41use crate::prediction::EditPrediction;
42pub use provider::ZetaEditPredictionProvider;
43
/// Consecutive edits to the same buffer that arrive within this window of each other are
/// coalesced into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of events kept in a project's edit history; once reached, the
/// oldest half of the history is discarded before a new event is appended.
const MAX_EVENT_COUNT: usize = 16;
48
49pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
50 use_imports: true,
51 max_retrieved_declarations: 0,
52 excerpt: EditPredictionExcerptOptions {
53 max_bytes: 512,
54 min_bytes: 128,
55 target_before_cursor_over_total_bytes: 0.5,
56 },
57 score: EditPredictionScoreOptions {
58 omit_excerpt_overlaps: true,
59 },
60};
61
62pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
63 context: DEFAULT_CONTEXT_OPTIONS,
64 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
65 max_diagnostic_bytes: 2048,
66 prompt_format: PromptFormat::DEFAULT,
67 file_indexing_parallelism: 1,
68};
69
/// Feature flag named `"zeta2"` that gates the zeta2 edit prediction model.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Returning `false` opts this flag out of the staff default
    // (NOTE(review): presumably it must then be granted explicitly — confirm against `FeatureFlag`).
    fn enabled_for_staff() -> bool {
        false
    }
}
79
/// Newtype that stores the app-wide [`Zeta`] entity as a gpui global
/// (installed by `Zeta::global`, read by `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
84
/// Client-side state for zeta2 edit predictions.
///
/// Tracks per-project edit history and current predictions, and sends prediction requests to
/// the Zed LLM service using a cached LLM API token.
pub struct Zeta {
    client: Arc<Client>,
    /// Read for edit prediction usage and updated when responses report new usage.
    user_store: Entity<UserStore>,
    /// Cached LLM API token; refreshed when the global refresh listener fires or when the
    /// server flags it as expired.
    llm_token: LlmApiToken,
    /// Held so the token-refresh subscription stays active for the lifetime of this entity.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server reports a minimum required version newer than this build.
    update_required: bool,
    /// Destination for per-request debug info, installed by `Zeta::debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
95
/// Tunable options for building zeta2 prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
104
/// Debug information emitted for each prediction request while a debug receiver is installed
/// (see `Zeta::debug_info`).
pub struct PredictionDebugInfo {
    /// Context gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent gathering `context`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from the request, or the error produced while rendering it.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server-side debug info, or an error string if the request failed,
    /// was skipped, or the response carried no debug info.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}

/// Server-side debug info attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
115
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Syntax index used to retrieve declarations for prediction context.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; capped by `MAX_EVENT_COUNT` in `Zeta::push_event`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
122
/// A project's active prediction together with the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the request was made from; this may differ from the buffer the
    /// prediction edits.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
128
129impl CurrentEditPrediction {
130 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
131 let Some(new_edits) = self
132 .prediction
133 .interpolate(&self.prediction.buffer.read(cx))
134 else {
135 return false;
136 };
137
138 if self.prediction.buffer != old_prediction.prediction.buffer {
139 return true;
140 }
141
142 let Some(old_edits) = old_prediction
143 .prediction
144 .interpolate(&old_prediction.prediction.buffer.read(cx))
145 else {
146 return true;
147 };
148
149 // This reduces the occurrence of UI thrash from replacing edits
150 //
151 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
152 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
153 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
154 && old_edits.len() == 1
155 && new_edits.len() == 1
156 {
157 let (old_range, old_text) = &old_edits[0];
158 let (new_range, new_text) = &new_edits[0];
159 new_range == old_range && new_text.starts_with(old_text)
160 } else {
161 true
162 }
163 }
164}
165
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits a different buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}
172
/// Tracking state for a buffer registered with [`Zeta`].
struct RegisteredBuffer {
    /// Last snapshot seen; compared against new snapshots to detect changes.
    snapshot: BufferSnapshot,
    /// Edit-event subscription and release observer for the buffer.
    _subscriptions: [gpui::Subscription; 2],
}
177
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Back-to-back changes may be
    /// merged into one event by `Zeta::push_event`, in which case `timestamp` is that of the
    /// latest merged change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
186
187impl Zeta {
188 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
189 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
190 }
191
192 pub fn global(
193 client: &Arc<Client>,
194 user_store: &Entity<UserStore>,
195 cx: &mut App,
196 ) -> Entity<Self> {
197 cx.try_global::<ZetaGlobal>()
198 .map(|global| global.0.clone())
199 .unwrap_or_else(|| {
200 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
201 cx.set_global(ZetaGlobal(zeta.clone()));
202 zeta
203 })
204 }
205
206 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
207 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
208
209 Self {
210 projects: HashMap::new(),
211 client,
212 user_store,
213 options: DEFAULT_OPTIONS,
214 llm_token: LlmApiToken::default(),
215 _llm_token_subscription: cx.subscribe(
216 &refresh_llm_token_listener,
217 |this, _listener, _event, cx| {
218 let client = this.client.clone();
219 let llm_token = this.llm_token.clone();
220 cx.spawn(async move |_this, _cx| {
221 llm_token.refresh(&client).await?;
222 anyhow::Ok(())
223 })
224 .detach_and_log_err(cx);
225 },
226 ),
227 update_required: false,
228 debug_tx: None,
229 }
230 }
231
232 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
233 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
234 self.debug_tx = Some(debug_watch_tx);
235 debug_watch_rx
236 }
237
238 pub fn options(&self) -> &ZetaOptions {
239 &self.options
240 }
241
242 pub fn set_options(&mut self, options: ZetaOptions) {
243 self.options = options;
244 }
245
246 pub fn clear_history(&mut self) {
247 for zeta_project in self.projects.values_mut() {
248 zeta_project.events.clear();
249 }
250 }
251
252 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
253 self.user_store.read(cx).edit_prediction_usage()
254 }
255
256 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
257 self.get_or_init_zeta_project(project, cx);
258 }
259
260 pub fn register_buffer(
261 &mut self,
262 buffer: &Entity<Buffer>,
263 project: &Entity<Project>,
264 cx: &mut Context<Self>,
265 ) {
266 let zeta_project = self.get_or_init_zeta_project(project, cx);
267 Self::register_buffer_impl(zeta_project, buffer, project, cx);
268 }
269
270 fn get_or_init_zeta_project(
271 &mut self,
272 project: &Entity<Project>,
273 cx: &mut App,
274 ) -> &mut ZetaProject {
275 self.projects
276 .entry(project.entity_id())
277 .or_insert_with(|| ZetaProject {
278 syntax_index: cx.new(|cx| {
279 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
280 }),
281 events: VecDeque::new(),
282 registered_buffers: HashMap::new(),
283 current_prediction: None,
284 })
285 }
286
287 fn register_buffer_impl<'a>(
288 zeta_project: &'a mut ZetaProject,
289 buffer: &Entity<Buffer>,
290 project: &Entity<Project>,
291 cx: &mut Context<Self>,
292 ) -> &'a mut RegisteredBuffer {
293 let buffer_id = buffer.entity_id();
294 match zeta_project.registered_buffers.entry(buffer_id) {
295 hash_map::Entry::Occupied(entry) => entry.into_mut(),
296 hash_map::Entry::Vacant(entry) => {
297 let snapshot = buffer.read(cx).snapshot();
298 let project_entity_id = project.entity_id();
299 entry.insert(RegisteredBuffer {
300 snapshot,
301 _subscriptions: [
302 cx.subscribe(buffer, {
303 let project = project.downgrade();
304 move |this, buffer, event, cx| {
305 if let language::BufferEvent::Edited = event
306 && let Some(project) = project.upgrade()
307 {
308 this.report_changes_for_buffer(&buffer, &project, cx);
309 }
310 }
311 }),
312 cx.observe_release(buffer, move |this, _buffer, _cx| {
313 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
314 else {
315 return;
316 };
317 zeta_project.registered_buffers.remove(&buffer_id);
318 }),
319 ],
320 })
321 }
322 }
323 }
324
325 fn report_changes_for_buffer(
326 &mut self,
327 buffer: &Entity<Buffer>,
328 project: &Entity<Project>,
329 cx: &mut Context<Self>,
330 ) -> BufferSnapshot {
331 let zeta_project = self.get_or_init_zeta_project(project, cx);
332 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
333
334 let new_snapshot = buffer.read(cx).snapshot();
335 if new_snapshot.version != registered_buffer.snapshot.version {
336 let old_snapshot =
337 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
338 Self::push_event(
339 zeta_project,
340 Event::BufferChange {
341 old_snapshot,
342 new_snapshot: new_snapshot.clone(),
343 timestamp: Instant::now(),
344 },
345 );
346 }
347
348 new_snapshot
349 }
350
351 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
352 let events = &mut zeta_project.events;
353
354 if let Some(Event::BufferChange {
355 new_snapshot: last_new_snapshot,
356 timestamp: last_timestamp,
357 ..
358 }) = events.back_mut()
359 {
360 // Coalesce edits for the same buffer when they happen one after the other.
361 let Event::BufferChange {
362 old_snapshot,
363 new_snapshot,
364 timestamp,
365 } = &event;
366
367 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
368 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
369 && old_snapshot.version == last_new_snapshot.version
370 {
371 *last_new_snapshot = new_snapshot.clone();
372 *last_timestamp = *timestamp;
373 return;
374 }
375 }
376
377 if events.len() >= MAX_EVENT_COUNT {
378 // These are halved instead of popping to improve prompt caching.
379 events.drain(..MAX_EVENT_COUNT / 2);
380 }
381
382 events.push_back(event);
383 }
384
385 fn current_prediction_for_buffer(
386 &self,
387 buffer: &Entity<Buffer>,
388 project: &Entity<Project>,
389 cx: &App,
390 ) -> Option<BufferEditPrediction<'_>> {
391 let project_state = self.projects.get(&project.entity_id())?;
392
393 let CurrentEditPrediction {
394 requested_by_buffer_id,
395 prediction,
396 } = project_state.current_prediction.as_ref()?;
397
398 if prediction.targets_buffer(buffer.read(cx), cx) {
399 Some(BufferEditPrediction::Local { prediction })
400 } else if *requested_by_buffer_id == buffer.entity_id() {
401 Some(BufferEditPrediction::Jump { prediction })
402 } else {
403 None
404 }
405 }
406
407 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
408 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
409 return;
410 };
411
412 let Some(prediction) = project_state.current_prediction.take() else {
413 return;
414 };
415 let request_id = prediction.prediction.id.into();
416
417 let client = self.client.clone();
418 let llm_token = self.llm_token.clone();
419 let app_version = AppVersion::global(cx);
420 cx.spawn(async move |this, cx| {
421 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
422 http_client::Url::parse(&predict_edits_url)?
423 } else {
424 client
425 .http_client()
426 .build_zed_llm_url("/predict_edits/accept", &[])?
427 };
428
429 let response = cx
430 .background_spawn(Self::send_api_request::<()>(
431 move |builder| {
432 let req = builder.uri(url.as_ref()).body(
433 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
434 );
435 Ok(req?)
436 },
437 client,
438 llm_token,
439 app_version,
440 ))
441 .await;
442
443 Self::handle_api_response(&this, response, cx)?;
444 anyhow::Ok(())
445 })
446 .detach_and_log_err(cx);
447 }
448
449 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
450 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
451 project_state.current_prediction.take();
452 };
453 }
454
455 pub fn refresh_prediction(
456 &mut self,
457 project: &Entity<Project>,
458 buffer: &Entity<Buffer>,
459 position: language::Anchor,
460 cx: &mut Context<Self>,
461 ) -> Task<Result<()>> {
462 let request_task = self.request_prediction(project, buffer, position, cx);
463 let buffer = buffer.clone();
464 let project = project.clone();
465
466 cx.spawn(async move |this, cx| {
467 if let Some(prediction) = request_task.await? {
468 this.update(cx, |this, cx| {
469 let project_state = this
470 .projects
471 .get_mut(&project.entity_id())
472 .context("Project not found")?;
473
474 let new_prediction = CurrentEditPrediction {
475 requested_by_buffer_id: buffer.entity_id(),
476 prediction: prediction,
477 };
478
479 if project_state
480 .current_prediction
481 .as_ref()
482 .is_none_or(|old_prediction| {
483 new_prediction.should_replace_prediction(&old_prediction, cx)
484 })
485 {
486 project_state.current_prediction = Some(new_prediction);
487 }
488 anyhow::Ok(())
489 })??;
490 }
491 Ok(())
492 })
493 }
494
495 fn request_prediction(
496 &mut self,
497 project: &Entity<Project>,
498 buffer: &Entity<Buffer>,
499 position: language::Anchor,
500 cx: &mut Context<Self>,
501 ) -> Task<Result<Option<EditPrediction>>> {
502 let project_state = self.projects.get(&project.entity_id());
503
504 let index_state = project_state.map(|state| {
505 state
506 .syntax_index
507 .read_with(cx, |index, _cx| index.state().clone())
508 });
509 let options = self.options.clone();
510 let snapshot = buffer.read(cx).snapshot();
511 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
512 return Task::ready(Err(anyhow!("No file path for excerpt")));
513 };
514 let client = self.client.clone();
515 let llm_token = self.llm_token.clone();
516 let app_version = AppVersion::global(cx);
517 let worktree_snapshots = project
518 .read(cx)
519 .worktrees(cx)
520 .map(|worktree| worktree.read(cx).snapshot())
521 .collect::<Vec<_>>();
522 let debug_tx = self.debug_tx.clone();
523
524 let events = project_state
525 .map(|state| {
526 state
527 .events
528 .iter()
529 .filter_map(|event| match event {
530 Event::BufferChange {
531 old_snapshot,
532 new_snapshot,
533 ..
534 } => {
535 let path = new_snapshot.file().map(|f| f.full_path(cx));
536
537 let old_path = old_snapshot.file().and_then(|f| {
538 let old_path = f.full_path(cx);
539 if Some(&old_path) != path.as_ref() {
540 Some(old_path)
541 } else {
542 None
543 }
544 });
545
546 // TODO [zeta2] move to bg?
547 let diff =
548 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
549
550 if path == old_path && diff.is_empty() {
551 None
552 } else {
553 Some(predict_edits_v3::Event::BufferChange {
554 old_path,
555 path,
556 diff,
557 //todo: Actually detect if this edit was predicted or not
558 predicted: false,
559 })
560 }
561 }
562 })
563 .collect::<Vec<_>>()
564 })
565 .unwrap_or_default();
566
567 let diagnostics = snapshot.diagnostic_sets().clone();
568
569 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
570 let mut path = f.worktree.read(cx).absolutize(&f.path);
571 if path.pop() { Some(path) } else { None }
572 });
573
574 let request_task = cx.background_spawn({
575 let snapshot = snapshot.clone();
576 let buffer = buffer.clone();
577 async move {
578 let index_state = if let Some(index_state) = index_state {
579 Some(index_state.lock_owned().await)
580 } else {
581 None
582 };
583
584 let cursor_offset = position.to_offset(&snapshot);
585 let cursor_point = cursor_offset.to_point(&snapshot);
586
587 let before_retrieval = chrono::Utc::now();
588
589 let Some(context) = EditPredictionContext::gather_context(
590 cursor_point,
591 &snapshot,
592 parent_abs_path.as_deref(),
593 &options.context,
594 index_state.as_deref(),
595 ) else {
596 return Ok((None, None));
597 };
598
599 let retrieval_time = chrono::Utc::now() - before_retrieval;
600
601 let (diagnostic_groups, diagnostic_groups_truncated) =
602 Self::gather_nearby_diagnostics(
603 cursor_offset,
604 &diagnostics,
605 &snapshot,
606 options.max_diagnostic_bytes,
607 );
608
609 let debug_context = debug_tx.map(|tx| (tx, context.clone()));
610
611 let request = make_cloud_request(
612 excerpt_path,
613 context,
614 events,
615 // TODO data collection
616 false,
617 diagnostic_groups,
618 diagnostic_groups_truncated,
619 None,
620 debug_context.is_some(),
621 &worktree_snapshots,
622 index_state.as_deref(),
623 Some(options.max_prompt_bytes),
624 options.prompt_format,
625 );
626
627 let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
628 let (response_tx, response_rx) = oneshot::channel();
629
630 let local_prompt = PlannedPrompt::populate(&request)
631 .and_then(|p| p.to_prompt_string().map(|p| p.0))
632 .map_err(|err| err.to_string());
633
634 debug_tx
635 .unbounded_send(PredictionDebugInfo {
636 context,
637 retrieval_time,
638 buffer: buffer.downgrade(),
639 local_prompt,
640 position,
641 response_rx,
642 })
643 .ok();
644 Some(response_tx)
645 } else {
646 None
647 };
648
649 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
650 if let Some(debug_response_tx) = debug_response_tx {
651 debug_response_tx
652 .send(Err("Request skipped".to_string()))
653 .ok();
654 }
655 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
656 }
657
658 let response =
659 Self::send_prediction_request(client, llm_token, app_version, request).await;
660
661 if let Some(debug_response_tx) = debug_response_tx {
662 debug_response_tx
663 .send(response.as_ref().map_err(|err| err.to_string()).and_then(
664 |response| match some_or_debug_panic(response.0.debug_info.clone()) {
665 Some(debug_info) => Ok(debug_info),
666 None => Err("Missing debug info".to_string()),
667 },
668 ))
669 .ok();
670 }
671
672 response.map(|(res, usage)| (Some(res), usage))
673 }
674 });
675
676 let buffer = buffer.clone();
677
678 cx.spawn({
679 let project = project.clone();
680 async move |this, cx| {
681 let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
682 else {
683 return Ok(None);
684 };
685
686 // TODO telemetry: duration, etc
687 Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
688 }
689 })
690 }
691
692 async fn send_prediction_request(
693 client: Arc<Client>,
694 llm_token: LlmApiToken,
695 app_version: SemanticVersion,
696 request: predict_edits_v3::PredictEditsRequest,
697 ) -> Result<(
698 predict_edits_v3::PredictEditsResponse,
699 Option<EditPredictionUsage>,
700 )> {
701 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
702 http_client::Url::parse(&predict_edits_url)?
703 } else {
704 client
705 .http_client()
706 .build_zed_llm_url("/predict_edits/v3", &[])?
707 };
708
709 Self::send_api_request(
710 |builder| {
711 let req = builder
712 .uri(url.as_ref())
713 .body(serde_json::to_string(&request)?.into());
714 Ok(req?)
715 },
716 client,
717 llm_token,
718 app_version,
719 )
720 .await
721 }
722
723 fn handle_api_response<T>(
724 this: &WeakEntity<Self>,
725 response: Result<(T, Option<EditPredictionUsage>)>,
726 cx: &mut gpui::AsyncApp,
727 ) -> Result<T> {
728 match response {
729 Ok((data, usage)) => {
730 if let Some(usage) = usage {
731 this.update(cx, |this, cx| {
732 this.user_store.update(cx, |user_store, cx| {
733 user_store.update_edit_prediction_usage(usage, cx);
734 });
735 })
736 .ok();
737 }
738 Ok(data)
739 }
740 Err(err) => {
741 if err.is::<ZedUpdateRequiredError>() {
742 cx.update(|cx| {
743 this.update(cx, |this, _cx| {
744 this.update_required = true;
745 })
746 .ok();
747
748 let error_message: SharedString = err.to_string().into();
749 show_app_notification(
750 NotificationId::unique::<ZedUpdateRequiredError>(),
751 cx,
752 move |cx| {
753 cx.new(|cx| {
754 ErrorMessagePrompt::new(error_message.clone(), cx)
755 .with_link_button("Update Zed", "https://zed.dev/releases")
756 })
757 },
758 );
759 })
760 .ok();
761 }
762 Err(err)
763 }
764 }
765 }
766
767 async fn send_api_request<Res>(
768 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
769 client: Arc<Client>,
770 llm_token: LlmApiToken,
771 app_version: SemanticVersion,
772 ) -> Result<(Res, Option<EditPredictionUsage>)>
773 where
774 Res: DeserializeOwned,
775 {
776 let http_client = client.http_client();
777 let mut token = llm_token.acquire(&client).await?;
778 let mut did_retry = false;
779
780 loop {
781 let request_builder = http_client::Request::builder().method(Method::POST);
782
783 let request = build(
784 request_builder
785 .header("Content-Type", "application/json")
786 .header("Authorization", format!("Bearer {}", token))
787 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
788 )?;
789
790 let mut response = http_client.send(request).await?;
791
792 if let Some(minimum_required_version) = response
793 .headers()
794 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
795 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
796 {
797 anyhow::ensure!(
798 app_version >= minimum_required_version,
799 ZedUpdateRequiredError {
800 minimum_version: minimum_required_version
801 }
802 );
803 }
804
805 if response.status().is_success() {
806 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
807
808 let mut body = Vec::new();
809 response.body_mut().read_to_end(&mut body).await?;
810 return Ok((serde_json::from_slice(&body)?, usage));
811 } else if !did_retry
812 && response
813 .headers()
814 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
815 .is_some()
816 {
817 did_retry = true;
818 token = llm_token.refresh(&client).await?;
819 } else {
820 let mut body = String::new();
821 response.body_mut().read_to_string(&mut body).await?;
822 anyhow::bail!(
823 "Request failed with status: {:?}\nBody: {}",
824 response.status(),
825 body
826 );
827 }
828 }
829 }
830
831 fn gather_nearby_diagnostics(
832 cursor_offset: usize,
833 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
834 snapshot: &BufferSnapshot,
835 max_diagnostics_bytes: usize,
836 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
837 // TODO: Could make this more efficient
838 let mut diagnostic_groups = Vec::new();
839 for (language_server_id, diagnostics) in diagnostic_sets {
840 let mut groups = Vec::new();
841 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
842 diagnostic_groups.extend(
843 groups
844 .into_iter()
845 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
846 );
847 }
848
849 // sort by proximity to cursor
850 diagnostic_groups.sort_by_key(|group| {
851 let range = &group.entries[group.primary_ix].range;
852 if range.start >= cursor_offset {
853 range.start - cursor_offset
854 } else if cursor_offset >= range.end {
855 cursor_offset - range.end
856 } else {
857 (cursor_offset - range.start).min(range.end - cursor_offset)
858 }
859 });
860
861 let mut results = Vec::new();
862 let mut diagnostic_groups_truncated = false;
863 let mut diagnostics_byte_count = 0;
864 for group in diagnostic_groups {
865 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
866 diagnostics_byte_count += raw_value.get().len();
867 if diagnostics_byte_count > max_diagnostics_bytes {
868 diagnostic_groups_truncated = true;
869 break;
870 }
871 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
872 }
873
874 (results, diagnostic_groups_truncated)
875 }
876
877 // TODO: Dedupe with similar code in request_prediction?
878 pub fn cloud_request_for_zeta_cli(
879 &mut self,
880 project: &Entity<Project>,
881 buffer: &Entity<Buffer>,
882 position: language::Anchor,
883 cx: &mut Context<Self>,
884 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
885 let project_state = self.projects.get(&project.entity_id());
886
887 let index_state = project_state.map(|state| {
888 state
889 .syntax_index
890 .read_with(cx, |index, _cx| index.state().clone())
891 });
892 let options = self.options.clone();
893 let snapshot = buffer.read(cx).snapshot();
894 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
895 return Task::ready(Err(anyhow!("No file path for excerpt")));
896 };
897 let worktree_snapshots = project
898 .read(cx)
899 .worktrees(cx)
900 .map(|worktree| worktree.read(cx).snapshot())
901 .collect::<Vec<_>>();
902
903 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
904 let mut path = f.worktree.read(cx).absolutize(&f.path);
905 if path.pop() { Some(path) } else { None }
906 });
907
908 cx.background_spawn(async move {
909 let index_state = if let Some(index_state) = index_state {
910 Some(index_state.lock_owned().await)
911 } else {
912 None
913 };
914
915 let cursor_point = position.to_point(&snapshot);
916
917 let debug_info = true;
918 EditPredictionContext::gather_context(
919 cursor_point,
920 &snapshot,
921 parent_abs_path.as_deref(),
922 &options.context,
923 index_state.as_deref(),
924 )
925 .context("Failed to select excerpt")
926 .map(|context| {
927 make_cloud_request(
928 excerpt_path.into(),
929 context,
930 // TODO pass everything
931 Vec::new(),
932 false,
933 Vec::new(),
934 false,
935 None,
936 debug_info,
937 &worktree_snapshots,
938 index_state.as_deref(),
939 Some(options.max_prompt_bytes),
940 options.prompt_format,
941 )
942 })
943 })
944 }
945
946 pub fn wait_for_initial_indexing(
947 &mut self,
948 project: &Entity<Project>,
949 cx: &mut App,
950 ) -> Task<Result<()>> {
951 let zeta_project = self.get_or_init_zeta_project(project, cx);
952 zeta_project
953 .syntax_index
954 .read(cx)
955 .wait_for_initial_file_indexing(cx)
956 }
957}
958
/// Returned when the server's advertised minimum required version is newer than the running
/// Zed version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version read from the server's minimum-required-version response header.
    minimum_version: SemanticVersion,
}
966
967fn make_cloud_request(
968 excerpt_path: Arc<Path>,
969 context: EditPredictionContext,
970 events: Vec<predict_edits_v3::Event>,
971 can_collect_data: bool,
972 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
973 diagnostic_groups_truncated: bool,
974 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
975 debug_info: bool,
976 worktrees: &Vec<worktree::Snapshot>,
977 index_state: Option<&SyntaxIndexState>,
978 prompt_max_bytes: Option<usize>,
979 prompt_format: PromptFormat,
980) -> predict_edits_v3::PredictEditsRequest {
981 let mut signatures = Vec::new();
982 let mut declaration_to_signature_index = HashMap::default();
983 let mut referenced_declarations = Vec::new();
984
985 for snippet in context.declarations {
986 let project_entry_id = snippet.declaration.project_entry_id();
987 let Some(path) = worktrees.iter().find_map(|worktree| {
988 worktree.entry_for_id(project_entry_id).map(|entry| {
989 let mut full_path = RelPathBuf::new();
990 full_path.push(worktree.root_name());
991 full_path.push(&entry.path);
992 full_path
993 })
994 }) else {
995 continue;
996 };
997
998 let parent_index = index_state.and_then(|index_state| {
999 snippet.declaration.parent().and_then(|parent| {
1000 add_signature(
1001 parent,
1002 &mut declaration_to_signature_index,
1003 &mut signatures,
1004 index_state,
1005 )
1006 })
1007 });
1008
1009 let (text, text_is_truncated) = snippet.declaration.item_text();
1010 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1011 path: path.as_std_path().into(),
1012 text: text.into(),
1013 range: snippet.declaration.item_line_range(),
1014 text_is_truncated,
1015 signature_range: snippet.declaration.signature_range_in_item_text(),
1016 parent_index,
1017 signature_score: snippet.score(DeclarationStyle::Signature),
1018 declaration_score: snippet.score(DeclarationStyle::Declaration),
1019 score_components: snippet.components,
1020 });
1021 }
1022
1023 let excerpt_parent = index_state.and_then(|index_state| {
1024 context
1025 .excerpt
1026 .parent_declarations
1027 .last()
1028 .and_then(|(parent, _)| {
1029 add_signature(
1030 *parent,
1031 &mut declaration_to_signature_index,
1032 &mut signatures,
1033 index_state,
1034 )
1035 })
1036 });
1037
1038 predict_edits_v3::PredictEditsRequest {
1039 excerpt_path,
1040 excerpt: context.excerpt_text.body,
1041 excerpt_line_range: context.excerpt.line_range,
1042 excerpt_range: context.excerpt.range,
1043 cursor_point: predict_edits_v3::Point {
1044 line: predict_edits_v3::Line(context.cursor_point.row),
1045 column: context.cursor_point.column,
1046 },
1047 referenced_declarations,
1048 signatures,
1049 excerpt_parent,
1050 events,
1051 can_collect_data,
1052 diagnostic_groups,
1053 diagnostic_groups_truncated,
1054 git_info,
1055 debug_info,
1056 prompt_max_bytes,
1057 prompt_format,
1058 }
1059}
1060
1061fn add_signature(
1062 declaration_id: DeclarationId,
1063 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1064 signatures: &mut Vec<Signature>,
1065 index: &SyntaxIndexState,
1066) -> Option<usize> {
1067 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1068 return Some(*signature_index);
1069 }
1070 let Some(parent_declaration) = index.declaration(declaration_id) else {
1071 log::error!("bug: missing parent declaration");
1072 return None;
1073 };
1074 let parent_index = parent_declaration.parent().and_then(|parent| {
1075 add_signature(parent, declaration_to_signature_index, signatures, index)
1076 });
1077 let (text, text_is_truncated) = parent_declaration.signature_text();
1078 let signature_index = signatures.len();
1079 signatures.push(Signature {
1080 text: text.into(),
1081 text_is_truncated,
1082 parent_index,
1083 range: parent_declaration.signature_line_range(),
1084 });
1085 declaration_to_signature_index.insert(declaration_id, signature_index);
1086 Some(signature_index)
1087}
1088
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Checks how `current_prediction_for_buffer` classifies the latest prediction.
    ///
    /// - A prediction that edits the queried buffer (`1.txt`) is `BufferEditPrediction::Local`.
    /// - A prediction that edits another file (`2.txt`) is `BufferEditPrediction::Jump` when
    ///   queried from `1.txt`.
    /// - The same `2.txt` prediction becomes `Local` once `2.txt` is opened and queried directly.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for the current file: the response edits 1.txt, the buffer being queried.

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Intercept the outgoing request and answer it through its oneshot sender.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file: the request still comes from buffer1, but the response
        // edits 2.txt.
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    // Uses 1.txt's row count. That is fine here because both fixture files have
                    // three lines.
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Once the target file is opened, the same prediction is local to that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Checks the request that `request_prediction` sends and how the response is turned into
    /// buffer edits.
    ///
    /// The request must carry the buffer's path and the cursor point. The server's whole-file
    /// replacement must come back as a single edit that inserts " are you?" at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Row 1, column 3 is the end of "How".
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // Only the inserted text survives, anchored at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an edit made to a registered buffer appears in the next request.
    ///
    /// The edit must be sent as a `BufferChange` event containing a unified diff, with
    /// `predicted: false`. The response is then handled the same way as in
    /// `test_simple_request`.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // The buffer is registered before it is edited, so the edit below is recorded.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line, right after "Hello!\n".
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that diagnostics published for a file are serialized into the request's
    /// `diagnostic_groups`.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // Spans (1, 1)..(1, 5). Line 1 is "Bye", so the end column lies past the end of the line.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        // The response is never sent. This test only inspects the outgoing request.
        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // Offsets: "Hello!\n" is 7 bytes, so (1, 1) becomes 8. The out-of-range end lands on 10,
        // the end of "Bye" and of the buffer.
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Sets up the test globals and returns the `Zeta` entity from `Zeta::global`, together with
    /// a receiver of intercepted prediction requests.
    ///
    /// The fake HTTP client behaves as follows:
    /// - `/client/llm_tokens` is answered with a fixed `"test"` token.
    /// - Each `/predict_edits/v3` request body is deserialized and sent to the receiver, paired
    ///   with a oneshot sender that the test uses to supply the response.
    /// - Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    // Despite the name, this holds only the request URI's path component.
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for its response. If the
                                // test drops the sender without answering, `?` turns the
                                // cancellation into an error from this fake request.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}