1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
10use edit_prediction_context::{
11 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
12 EditPredictionExcerptOptions, EditPredictionScoreOptions, SimilarSnippetOptions, SyntaxIndex,
13 SyntaxIndexState,
14};
15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
16use futures::AsyncReadExt as _;
17use futures::channel::{mpsc, oneshot};
18use gpui::http_client::{AsyncBody, Method};
19use gpui::{
20 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
21 http_client, prelude::*,
22};
23use language::BufferSnapshot;
24use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language_model::{LlmApiToken, RefreshLlmTokenListener};
26use project::Project;
27use release_channel::AppVersion;
28use serde::de::DeserializeOwned;
29use std::collections::{HashMap, VecDeque, hash_map};
30use std::path::Path;
31use std::str::FromStr as _;
32use std::sync::Arc;
33use std::time::{Duration, Instant};
34use thiserror::Error;
35use util::rel_path::RelPathBuf;
36use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
37
38mod prediction;
39mod provider;
40
41use crate::prediction::EditPrediction;
42pub use provider::ZetaEditPredictionProvider;
43
/// Two successive changes to the same buffer are merged into a single
/// [`Event::BufferChange`] when the later one is recorded no more than this long after
/// the earlier one (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// The limit applies per project: once a project's queue holds this many events, the
/// oldest half is dropped before the next event is appended.
const MAX_EVENT_COUNT: usize = 16;
48
/// Context-gathering settings used by [`DEFAULT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    max_retrieved_declarations: 0,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
    similar_snippets: SimilarSnippetOptions::DEFAULT,
};
62
/// Options every [`Zeta`] created through `Zeta::new` starts with; they can be replaced
/// later via `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
70
/// Feature flag registered under the name `"zeta2"`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Always returns `false`; presumably this keeps staff from getting the flag
    // automatically — confirm against the `FeatureFlag` trait documentation.
    fn enabled_for_staff() -> bool {
        false
    }
}
80
/// App-global wrapper around the shared [`Zeta`] entity.
///
/// Installed by `Zeta::global` and looked up by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
85
/// Edit-prediction state shared across projects: per-project event history and current
/// prediction, plus everything needed to talk to the prediction API.
pub struct Zeta {
    /// Used to obtain the HTTP client and to acquire/refresh the LLM token.
    client: Arc<Client>,
    /// Source of the current edit-prediction usage; updated with the usage reported by
    /// API responses.
    user_store: Entity<UserStore>,
    /// Token sent as the `Authorization: Bearer` header on API requests.
    llm_token: LlmApiToken,
    /// Keeps the `RefreshLlmTokenListener` subscription (which refreshes `llm_token`) alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to subsequent requests and to newly created syntax indexes.
    options: ZetaOptions,
    /// Set to `true` once an API response fails with [`ZedUpdateRequiredError`].
    update_required: bool,
    /// When present, every prediction request also sends a [`PredictionDebugInfo`] here.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
96
/// Tunable settings for building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Forwarded to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups attached to a request; groups
    /// that would exceed it are left out and the request is flagged as truncated.
    pub max_diagnostic_bytes: usize,
    /// Forwarded to the server as the request's `prompt_format`.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project's index is first created.
    pub file_indexing_parallelism: usize,
}
105
/// Snapshot of a single prediction request, sent to the receiver returned by
/// `Zeta::debug_info` right before the request goes out.
pub struct PredictionDebugInfo {
    /// The request as it is about to be sent.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Wall-clock time spent in `EditPredictionContext::gather_context`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from `request`, or the rendering error as a string.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response (or the error as a string) once the request
    /// finishes or is skipped.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Alias for the debug payload type defined by the prediction API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
116
/// State `Zeta` keeps for one registered project.
struct ZetaProject {
    /// Syntax index created for the project when it is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first, bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers being tracked for changes, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently on offer for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
123
124#[derive(Debug, Clone)]
125struct CurrentEditPrediction {
126 pub requested_by_buffer_id: EntityId,
127 pub prediction: EditPrediction,
128}
129
130impl CurrentEditPrediction {
131 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
132 let Some(new_edits) = self
133 .prediction
134 .interpolate(&self.prediction.buffer.read(cx))
135 else {
136 return false;
137 };
138
139 if self.prediction.buffer != old_prediction.prediction.buffer {
140 return true;
141 }
142
143 let Some(old_edits) = old_prediction
144 .prediction
145 .interpolate(&old_prediction.prediction.buffer.read(cx))
146 else {
147 return true;
148 };
149
150 // This reduces the occurrence of UI thrash from replacing edits
151 //
152 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
153 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
154 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
155 && old_edits.len() == 1
156 && new_edits.len() == 1
157 {
158 let (old_range, old_text) = &old_edits[0];
159 let (new_range, new_text) = &new_edits[0];
160 new_range == old_range && new_text.starts_with(old_text)
161 } else {
162 true
163 }
164 }
165}
166
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets the buffer being asked about.
    Local { prediction: &'a EditPrediction },
    /// The prediction does not target the buffer being asked about, but that buffer is
    /// the one that requested it.
    Jump { prediction: &'a EditPrediction },
}
173
/// Tracking state for a buffer registered with a [`ZetaProject`].
struct RegisteredBuffer {
    /// Last snapshot recorded for the buffer; replaced whenever a change is reported.
    snapshot: BufferSnapshot,
    /// The buffer-event subscription and the release observer installed at registration.
    _subscriptions: [gpui::Subscription; 2],
}
178
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer contents before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer contents after the change.
        new_snapshot: BufferSnapshot,
        /// When the change (or the latest change coalesced into it) was recorded.
        timestamp: Instant,
    },
}
187
188impl Zeta {
189 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
190 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
191 }
192
193 pub fn global(
194 client: &Arc<Client>,
195 user_store: &Entity<UserStore>,
196 cx: &mut App,
197 ) -> Entity<Self> {
198 cx.try_global::<ZetaGlobal>()
199 .map(|global| global.0.clone())
200 .unwrap_or_else(|| {
201 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
202 cx.set_global(ZetaGlobal(zeta.clone()));
203 zeta
204 })
205 }
206
207 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
208 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
209
210 Self {
211 projects: HashMap::new(),
212 client,
213 user_store,
214 options: DEFAULT_OPTIONS,
215 llm_token: LlmApiToken::default(),
216 _llm_token_subscription: cx.subscribe(
217 &refresh_llm_token_listener,
218 |this, _listener, _event, cx| {
219 let client = this.client.clone();
220 let llm_token = this.llm_token.clone();
221 cx.spawn(async move |_this, _cx| {
222 llm_token.refresh(&client).await?;
223 anyhow::Ok(())
224 })
225 .detach_and_log_err(cx);
226 },
227 ),
228 update_required: false,
229 debug_tx: None,
230 }
231 }
232
233 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
234 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
235 self.debug_tx = Some(debug_watch_tx);
236 debug_watch_rx
237 }
238
    /// Returns the options currently in effect.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
242
    /// Replaces the options used for subsequent requests.
    ///
    /// Syntax indexes that already exist are not touched here; `file_indexing_parallelism`
    /// is only read when a project's index is first created.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
246
247 pub fn clear_history(&mut self) {
248 for zeta_project in self.projects.values_mut() {
249 zeta_project.events.clear();
250 }
251 }
252
253 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
254 self.user_store.read(cx).edit_prediction_usage()
255 }
256
257 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
258 self.get_or_init_zeta_project(project, cx);
259 }
260
261 pub fn register_buffer(
262 &mut self,
263 buffer: &Entity<Buffer>,
264 project: &Entity<Project>,
265 cx: &mut Context<Self>,
266 ) {
267 let zeta_project = self.get_or_init_zeta_project(project, cx);
268 Self::register_buffer_impl(zeta_project, buffer, project, cx);
269 }
270
271 fn get_or_init_zeta_project(
272 &mut self,
273 project: &Entity<Project>,
274 cx: &mut App,
275 ) -> &mut ZetaProject {
276 self.projects
277 .entry(project.entity_id())
278 .or_insert_with(|| ZetaProject {
279 syntax_index: cx.new(|cx| {
280 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
281 }),
282 events: VecDeque::new(),
283 registered_buffers: HashMap::new(),
284 current_prediction: None,
285 })
286 }
287
288 fn register_buffer_impl<'a>(
289 zeta_project: &'a mut ZetaProject,
290 buffer: &Entity<Buffer>,
291 project: &Entity<Project>,
292 cx: &mut Context<Self>,
293 ) -> &'a mut RegisteredBuffer {
294 let buffer_id = buffer.entity_id();
295 match zeta_project.registered_buffers.entry(buffer_id) {
296 hash_map::Entry::Occupied(entry) => entry.into_mut(),
297 hash_map::Entry::Vacant(entry) => {
298 let snapshot = buffer.read(cx).snapshot();
299 let project_entity_id = project.entity_id();
300 entry.insert(RegisteredBuffer {
301 snapshot,
302 _subscriptions: [
303 cx.subscribe(buffer, {
304 let project = project.downgrade();
305 move |this, buffer, event, cx| {
306 if let language::BufferEvent::Edited = event
307 && let Some(project) = project.upgrade()
308 {
309 this.report_changes_for_buffer(&buffer, &project, cx);
310 }
311 }
312 }),
313 cx.observe_release(buffer, move |this, _buffer, _cx| {
314 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
315 else {
316 return;
317 };
318 zeta_project.registered_buffers.remove(&buffer_id);
319 }),
320 ],
321 })
322 }
323 }
324 }
325
326 fn report_changes_for_buffer(
327 &mut self,
328 buffer: &Entity<Buffer>,
329 project: &Entity<Project>,
330 cx: &mut Context<Self>,
331 ) -> BufferSnapshot {
332 let zeta_project = self.get_or_init_zeta_project(project, cx);
333 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
334
335 let new_snapshot = buffer.read(cx).snapshot();
336 if new_snapshot.version != registered_buffer.snapshot.version {
337 let old_snapshot =
338 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
339 Self::push_event(
340 zeta_project,
341 Event::BufferChange {
342 old_snapshot,
343 new_snapshot: new_snapshot.clone(),
344 timestamp: Instant::now(),
345 },
346 );
347 }
348
349 new_snapshot
350 }
351
352 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
353 let events = &mut zeta_project.events;
354
355 if let Some(Event::BufferChange {
356 new_snapshot: last_new_snapshot,
357 timestamp: last_timestamp,
358 ..
359 }) = events.back_mut()
360 {
361 // Coalesce edits for the same buffer when they happen one after the other.
362 let Event::BufferChange {
363 old_snapshot,
364 new_snapshot,
365 timestamp,
366 } = &event;
367
368 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
369 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
370 && old_snapshot.version == last_new_snapshot.version
371 {
372 *last_new_snapshot = new_snapshot.clone();
373 *last_timestamp = *timestamp;
374 return;
375 }
376 }
377
378 if events.len() >= MAX_EVENT_COUNT {
379 // These are halved instead of popping to improve prompt caching.
380 events.drain(..MAX_EVENT_COUNT / 2);
381 }
382
383 events.push_back(event);
384 }
385
386 fn current_prediction_for_buffer(
387 &self,
388 buffer: &Entity<Buffer>,
389 project: &Entity<Project>,
390 cx: &App,
391 ) -> Option<BufferEditPrediction<'_>> {
392 let project_state = self.projects.get(&project.entity_id())?;
393
394 let CurrentEditPrediction {
395 requested_by_buffer_id,
396 prediction,
397 } = project_state.current_prediction.as_ref()?;
398
399 if prediction.targets_buffer(buffer.read(cx), cx) {
400 Some(BufferEditPrediction::Local { prediction })
401 } else if *requested_by_buffer_id == buffer.entity_id() {
402 Some(BufferEditPrediction::Jump { prediction })
403 } else {
404 None
405 }
406 }
407
408 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
409 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
410 return;
411 };
412
413 let Some(prediction) = project_state.current_prediction.take() else {
414 return;
415 };
416 let request_id = prediction.prediction.id.into();
417
418 let client = self.client.clone();
419 let llm_token = self.llm_token.clone();
420 let app_version = AppVersion::global(cx);
421 cx.spawn(async move |this, cx| {
422 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
423 http_client::Url::parse(&predict_edits_url)?
424 } else {
425 client
426 .http_client()
427 .build_zed_llm_url("/predict_edits/accept", &[])?
428 };
429
430 let response = cx
431 .background_spawn(Self::send_api_request::<()>(
432 move |builder| {
433 let req = builder.uri(url.as_ref()).body(
434 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
435 );
436 Ok(req?)
437 },
438 client,
439 llm_token,
440 app_version,
441 ))
442 .await;
443
444 Self::handle_api_response(&this, response, cx)?;
445 anyhow::Ok(())
446 })
447 .detach_and_log_err(cx);
448 }
449
450 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
451 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
452 project_state.current_prediction.take();
453 };
454 }
455
456 pub fn refresh_prediction(
457 &mut self,
458 project: &Entity<Project>,
459 buffer: &Entity<Buffer>,
460 position: language::Anchor,
461 cx: &mut Context<Self>,
462 ) -> Task<Result<()>> {
463 let request_task = self.request_prediction(project, buffer, position, cx);
464 let buffer = buffer.clone();
465 let project = project.clone();
466
467 cx.spawn(async move |this, cx| {
468 if let Some(prediction) = request_task.await? {
469 this.update(cx, |this, cx| {
470 let project_state = this
471 .projects
472 .get_mut(&project.entity_id())
473 .context("Project not found")?;
474
475 let new_prediction = CurrentEditPrediction {
476 requested_by_buffer_id: buffer.entity_id(),
477 prediction: prediction,
478 };
479
480 if project_state
481 .current_prediction
482 .as_ref()
483 .is_none_or(|old_prediction| {
484 new_prediction.should_replace_prediction(&old_prediction, cx)
485 })
486 {
487 project_state.current_prediction = Some(new_prediction);
488 }
489 anyhow::Ok(())
490 })??;
491 }
492 Ok(())
493 })
494 }
495
    /// Builds a prediction request for `buffer` at `position`, sends it, and turns the
    /// response into an [`EditPrediction`].
    ///
    /// Everything that needs app state (snapshots, paths, recent events, diagnostics) is
    /// collected up front on the calling context. Context gathering, request construction
    /// and the HTTP call then run in a background task, and the response is processed back
    /// in a foreground task.
    ///
    /// The returned task resolves to `Ok(None)` when no context could be gathered for the
    /// cursor position, and to an error when the buffer has no file, the request fails, or
    /// (debug builds only) `ZED_ZETA2_SKIP_REQUEST` is set.
    fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // The project may not be registered; in that case there is no index and no events.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert the recorded buffer changes into API events (path, optional old path,
        // unified diff).
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| match event {
                        Event::BufferChange {
                            old_snapshot,
                            new_snapshot,
                            ..
                        } => {
                            let path = new_snapshot.file().map(|f| f.full_path(cx));

                            // Only populated when the old path differs from the new one.
                            let old_path = old_snapshot.file().and_then(|f| {
                                let old_path = f.full_path(cx);
                                if Some(&old_path) != path.as_ref() {
                                    Some(old_path)
                                } else {
                                    None
                                }
                            });

                            // TODO [zeta2] move to bg?
                            let diff =
                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                            // NOTE(review): `old_path` is `None` whenever it equals `path`,
                            // so `path == old_path` only holds when neither snapshot has a
                            // file. An empty diff on an unchanged, existing path is therefore
                            // still sent — confirm that this is intended.
                            if path == old_path && diff.is_empty() {
                                None
                            } else {
                                Some(predict_edits_v3::Event::BufferChange {
                                    old_path,
                                    path,
                                    diff,
                                    //todo: Actually detect if this edit was predicted or not
                                    predicted: false,
                                })
                            }
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        // Absolute path of the directory containing the buffer's file, when it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                // Lock the syntax index state (if any) for the duration of this task.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                // Without context there is nothing to send; report "no prediction".
                let Some(context) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    parent_abs_path.as_deref(),
                    &options.context,
                    index_state.as_deref(),
                ) else {
                    return Ok((None, None));
                };

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = make_cloud_request(
                    excerpt_path,
                    context,
                    events,
                    can_collect_data,
                    diagnostic_groups,
                    diagnostic_groups_truncated,
                    None,
                    debug_tx.is_some(),
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                );

                // When a debug listener is attached, hand it the request, the locally
                // rendered prompt and a channel on which the response will be delivered.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    let local_prompt = PlannedPrompt::populate(&request)
                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
                        .map_err(|err| err.to_string());

                    debug_tx
                        .unbounded_send(PredictionDebugInfo {
                            request: request.clone(),
                            retrieval_time,
                            buffer: buffer.downgrade(),
                            local_prompt,
                            position,
                            response_rx,
                        })
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug builds can skip the network call entirely via an env var.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send(Err("Request skipped".to_string()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let response =
                    Self::send_prediction_request(client, llm_token, app_version, request).await;

                // Forward the outcome (response body or error text) to the debug listener.
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                        )
                        .ok();
                }

                response.map(|(res, usage)| (Some(res), usage))
            }
        });

        let buffer = buffer.clone();

        cx.spawn({
            let project = project.clone();
            async move |this, cx| {
                // Records usage / update-required state, then unwraps the optional response.
                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
            }
        })
    }
692
693 async fn send_prediction_request(
694 client: Arc<Client>,
695 llm_token: LlmApiToken,
696 app_version: SemanticVersion,
697 request: predict_edits_v3::PredictEditsRequest,
698 ) -> Result<(
699 predict_edits_v3::PredictEditsResponse,
700 Option<EditPredictionUsage>,
701 )> {
702 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
703 http_client::Url::parse(&predict_edits_url)?
704 } else {
705 client
706 .http_client()
707 .build_zed_llm_url("/predict_edits/v3", &[])?
708 };
709
710 Self::send_api_request(
711 |builder| {
712 let req = builder
713 .uri(url.as_ref())
714 .body(serde_json::to_string(&request)?.into());
715 Ok(req?)
716 },
717 client,
718 llm_token,
719 app_version,
720 )
721 .await
722 }
723
724 fn handle_api_response<T>(
725 this: &WeakEntity<Self>,
726 response: Result<(T, Option<EditPredictionUsage>)>,
727 cx: &mut gpui::AsyncApp,
728 ) -> Result<T> {
729 match response {
730 Ok((data, usage)) => {
731 if let Some(usage) = usage {
732 this.update(cx, |this, cx| {
733 this.user_store.update(cx, |user_store, cx| {
734 user_store.update_edit_prediction_usage(usage, cx);
735 });
736 })
737 .ok();
738 }
739 Ok(data)
740 }
741 Err(err) => {
742 if err.is::<ZedUpdateRequiredError>() {
743 cx.update(|cx| {
744 this.update(cx, |this, _cx| {
745 this.update_required = true;
746 })
747 .ok();
748
749 let error_message: SharedString = err.to_string().into();
750 show_app_notification(
751 NotificationId::unique::<ZedUpdateRequiredError>(),
752 cx,
753 move |cx| {
754 cx.new(|cx| {
755 ErrorMessagePrompt::new(error_message.clone(), cx)
756 .with_link_button("Update Zed", "https://zed.dev/releases")
757 })
758 },
759 );
760 })
761 .ok();
762 }
763 Err(err)
764 }
765 }
766 }
767
768 async fn send_api_request<Res>(
769 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
770 client: Arc<Client>,
771 llm_token: LlmApiToken,
772 app_version: SemanticVersion,
773 ) -> Result<(Res, Option<EditPredictionUsage>)>
774 where
775 Res: DeserializeOwned,
776 {
777 let http_client = client.http_client();
778 let mut token = llm_token.acquire(&client).await?;
779 let mut did_retry = false;
780
781 loop {
782 let request_builder = http_client::Request::builder().method(Method::POST);
783
784 let request = build(
785 request_builder
786 .header("Content-Type", "application/json")
787 .header("Authorization", format!("Bearer {}", token))
788 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
789 )?;
790
791 let mut response = http_client.send(request).await?;
792
793 if let Some(minimum_required_version) = response
794 .headers()
795 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
796 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
797 {
798 anyhow::ensure!(
799 app_version >= minimum_required_version,
800 ZedUpdateRequiredError {
801 minimum_version: minimum_required_version
802 }
803 );
804 }
805
806 if response.status().is_success() {
807 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
808
809 let mut body = Vec::new();
810 response.body_mut().read_to_end(&mut body).await?;
811 return Ok((serde_json::from_slice(&body)?, usage));
812 } else if !did_retry
813 && response
814 .headers()
815 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
816 .is_some()
817 {
818 did_retry = true;
819 token = llm_token.refresh(&client).await?;
820 } else {
821 let mut body = String::new();
822 response.body_mut().read_to_string(&mut body).await?;
823 anyhow::bail!(
824 "Request failed with status: {:?}\nBody: {}",
825 response.status(),
826 body
827 );
828 }
829 }
830 }
831
832 fn gather_nearby_diagnostics(
833 cursor_offset: usize,
834 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
835 snapshot: &BufferSnapshot,
836 max_diagnostics_bytes: usize,
837 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
838 // TODO: Could make this more efficient
839 let mut diagnostic_groups = Vec::new();
840 for (language_server_id, diagnostics) in diagnostic_sets {
841 let mut groups = Vec::new();
842 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
843 diagnostic_groups.extend(
844 groups
845 .into_iter()
846 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
847 );
848 }
849
850 // sort by proximity to cursor
851 diagnostic_groups.sort_by_key(|group| {
852 let range = &group.entries[group.primary_ix].range;
853 if range.start >= cursor_offset {
854 range.start - cursor_offset
855 } else if cursor_offset >= range.end {
856 cursor_offset - range.end
857 } else {
858 (cursor_offset - range.start).min(range.end - cursor_offset)
859 }
860 });
861
862 let mut results = Vec::new();
863 let mut diagnostic_groups_truncated = false;
864 let mut diagnostics_byte_count = 0;
865 for group in diagnostic_groups {
866 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
867 diagnostics_byte_count += raw_value.get().len();
868 if diagnostics_byte_count > max_diagnostics_bytes {
869 diagnostic_groups_truncated = true;
870 break;
871 }
872 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
873 }
874
875 (results, diagnostic_groups_truncated)
876 }
877
878 // TODO: Dedupe with similar code in request_prediction?
879 pub fn cloud_request_for_zeta_cli(
880 &mut self,
881 project: &Entity<Project>,
882 buffer: &Entity<Buffer>,
883 position: language::Anchor,
884 cx: &mut Context<Self>,
885 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
886 let project_state = self.projects.get(&project.entity_id());
887
888 let index_state = project_state.map(|state| {
889 state
890 .syntax_index
891 .read_with(cx, |index, _cx| index.state().clone())
892 });
893 let options = self.options.clone();
894 let snapshot = buffer.read(cx).snapshot();
895 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
896 return Task::ready(Err(anyhow!("No file path for excerpt")));
897 };
898 let worktree_snapshots = project
899 .read(cx)
900 .worktrees(cx)
901 .map(|worktree| worktree.read(cx).snapshot())
902 .collect::<Vec<_>>();
903
904 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
905 let mut path = f.worktree.read(cx).absolutize(&f.path);
906 if path.pop() { Some(path) } else { None }
907 });
908
909 cx.background_spawn(async move {
910 let index_state = if let Some(index_state) = index_state {
911 Some(index_state.lock_owned().await)
912 } else {
913 None
914 };
915
916 let cursor_point = position.to_point(&snapshot);
917
918 let debug_info = true;
919 EditPredictionContext::gather_context(
920 cursor_point,
921 &snapshot,
922 parent_abs_path.as_deref(),
923 &options.context,
924 index_state.as_deref(),
925 )
926 .context("Failed to select excerpt")
927 .map(|context| {
928 make_cloud_request(
929 excerpt_path.into(),
930 context,
931 // TODO pass everything
932 Vec::new(),
933 false,
934 Vec::new(),
935 false,
936 None,
937 debug_info,
938 &worktree_snapshots,
939 index_state.as_deref(),
940 Some(options.max_prompt_bytes),
941 options.prompt_format,
942 )
943 })
944 })
945 }
946
947 pub fn wait_for_initial_indexing(
948 &mut self,
949 project: &Entity<Project>,
950 cx: &mut App,
951 ) -> Task<Result<()>> {
952 let zeta_project = self.get_or_init_zeta_project(project, cx);
953 zeta_project
954 .syntax_index
955 .read(cx)
956 .wait_for_initial_file_indexing(cx)
957 }
958}
959
/// Returned by `Zeta::send_api_request` when the server's minimum-required-version
/// header names a version newer than the running app.
///
/// `Zeta::handle_api_response` reacts to this error by setting `update_required` and
/// showing a notification with a link to the releases page.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum app version the server reported.
    minimum_version: SemanticVersion,
}
967
968fn make_cloud_request(
969 excerpt_path: Arc<Path>,
970 context: EditPredictionContext,
971 events: Vec<predict_edits_v3::Event>,
972 can_collect_data: bool,
973 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
974 diagnostic_groups_truncated: bool,
975 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
976 debug_info: bool,
977 worktrees: &Vec<worktree::Snapshot>,
978 index_state: Option<&SyntaxIndexState>,
979 prompt_max_bytes: Option<usize>,
980 prompt_format: PromptFormat,
981) -> predict_edits_v3::PredictEditsRequest {
982 let mut signatures = Vec::new();
983 let mut declaration_to_signature_index = HashMap::default();
984 let mut referenced_declarations = Vec::new();
985
986 for snippet in context.declarations {
987 let project_entry_id = snippet.declaration.project_entry_id();
988 let Some(path) = worktrees.iter().find_map(|worktree| {
989 worktree.entry_for_id(project_entry_id).map(|entry| {
990 let mut full_path = RelPathBuf::new();
991 full_path.push(worktree.root_name());
992 full_path.push(&entry.path);
993 full_path
994 })
995 }) else {
996 continue;
997 };
998
999 let parent_index = index_state.and_then(|index_state| {
1000 snippet.declaration.parent().and_then(|parent| {
1001 add_signature(
1002 parent,
1003 &mut declaration_to_signature_index,
1004 &mut signatures,
1005 index_state,
1006 )
1007 })
1008 });
1009
1010 let (text, text_is_truncated) = snippet.declaration.item_text();
1011 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1012 path: path.as_std_path().into(),
1013 text: text.into(),
1014 range: snippet.declaration.item_line_range(),
1015 text_is_truncated,
1016 signature_range: snippet.declaration.signature_range_in_item_text(),
1017 parent_index,
1018 signature_score: snippet.score(DeclarationStyle::Signature),
1019 declaration_score: snippet.score(DeclarationStyle::Declaration),
1020 score_components: snippet.components,
1021 });
1022 }
1023
1024 let excerpt_parent = index_state.and_then(|index_state| {
1025 context
1026 .excerpt
1027 .parent_declarations
1028 .last()
1029 .and_then(|(parent, _)| {
1030 add_signature(
1031 *parent,
1032 &mut declaration_to_signature_index,
1033 &mut signatures,
1034 index_state,
1035 )
1036 })
1037 });
1038
1039 predict_edits_v3::PredictEditsRequest {
1040 excerpt_path,
1041 excerpt: context.excerpt_text.body,
1042 excerpt_line_range: context.excerpt.line_range,
1043 excerpt_range: context.excerpt.range,
1044 cursor_point: predict_edits_v3::Point {
1045 line: predict_edits_v3::Line(context.cursor_point.row),
1046 column: context.cursor_point.column,
1047 },
1048 referenced_declarations,
1049 signatures,
1050 excerpt_parent,
1051 events,
1052 can_collect_data,
1053 diagnostic_groups,
1054 diagnostic_groups_truncated,
1055 git_info,
1056 debug_info,
1057 prompt_max_bytes,
1058 prompt_format,
1059 }
1060}
1061
1062fn add_signature(
1063 declaration_id: DeclarationId,
1064 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1065 signatures: &mut Vec<Signature>,
1066 index: &SyntaxIndexState,
1067) -> Option<usize> {
1068 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1069 return Some(*signature_index);
1070 }
1071 let Some(parent_declaration) = index.declaration(declaration_id) else {
1072 log::error!("bug: missing parent declaration");
1073 return None;
1074 };
1075 let parent_index = parent_declaration.parent().and_then(|parent| {
1076 add_signature(parent, declaration_to_signature_index, signatures, index)
1077 });
1078 let (text, text_is_truncated) = parent_declaration.signature_text();
1079 let signature_index = signatures.len();
1080 signatures.push(Signature {
1081 text: text.into(),
1082 text_is_truncated,
1083 parent_index,
1084 range: parent_declaration.signature_line_range(),
1085 });
1086 declaration_to_signature_index.insert(declaration_id, signature_index);
1087 Some(signature_index)
1088}
1089
1090#[cfg(test)]
1091mod tests {
1092 use std::{
1093 path::{Path, PathBuf},
1094 sync::Arc,
1095 };
1096
1097 use client::UserStore;
1098 use clock::FakeSystemClock;
1099 use cloud_llm_client::predict_edits_v3::{self, Point};
1100 use edit_prediction_context::Line;
1101 use futures::{
1102 AsyncReadExt, StreamExt,
1103 channel::{mpsc, oneshot},
1104 };
1105 use gpui::{
1106 Entity, TestAppContext,
1107 http_client::{FakeHttpClient, Response},
1108 prelude::*,
1109 };
1110 use indoc::indoc;
1111 use language::{LanguageServerId, OffsetRangeExt as _};
1112 use pretty_assertions::{assert_eq, assert_matches};
1113 use project::{FakeFs, Project};
1114 use serde_json::json;
1115 use settings::SettingsStore;
1116 use util::path;
1117 use uuid::Uuid;
1118
1119 use crate::{BufferEditPrediction, Zeta};
1120
1121 #[gpui::test]
1122 async fn test_current_state(cx: &mut TestAppContext) {
1123 let (zeta, mut req_rx) = init_test(cx);
1124 let fs = FakeFs::new(cx.executor());
1125 fs.insert_tree(
1126 "/root",
1127 json!({
1128 "1.txt": "Hello!\nHow\nBye",
1129 "2.txt": "Hola!\nComo\nAdios"
1130 }),
1131 )
1132 .await;
1133 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1134
1135 zeta.update(cx, |zeta, cx| {
1136 zeta.register_project(&project, cx);
1137 });
1138
1139 let buffer1 = project
1140 .update(cx, |project, cx| {
1141 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1142 project.open_buffer(path, cx)
1143 })
1144 .await
1145 .unwrap();
1146 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1147 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1148
1149 // Prediction for current file
1150
1151 let prediction_task = zeta.update(cx, |zeta, cx| {
1152 zeta.refresh_prediction(&project, &buffer1, position, cx)
1153 });
1154 let (_request, respond_tx) = req_rx.next().await.unwrap();
1155 respond_tx
1156 .send(predict_edits_v3::PredictEditsResponse {
1157 request_id: Uuid::new_v4(),
1158 edits: vec![predict_edits_v3::Edit {
1159 path: Path::new(path!("root/1.txt")).into(),
1160 range: Line(0)..Line(snapshot1.max_point().row + 1),
1161 content: "Hello!\nHow are you?\nBye".into(),
1162 }],
1163 debug_info: None,
1164 })
1165 .unwrap();
1166 prediction_task.await.unwrap();
1167
1168 zeta.read_with(cx, |zeta, cx| {
1169 let prediction = zeta
1170 .current_prediction_for_buffer(&buffer1, &project, cx)
1171 .unwrap();
1172 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1173 });
1174
1175 // Prediction for another file
1176 let prediction_task = zeta.update(cx, |zeta, cx| {
1177 zeta.refresh_prediction(&project, &buffer1, position, cx)
1178 });
1179 let (_request, respond_tx) = req_rx.next().await.unwrap();
1180 respond_tx
1181 .send(predict_edits_v3::PredictEditsResponse {
1182 request_id: Uuid::new_v4(),
1183 edits: vec![predict_edits_v3::Edit {
1184 path: Path::new(path!("root/2.txt")).into(),
1185 range: Line(0)..Line(snapshot1.max_point().row + 1),
1186 content: "Hola!\nComo estas?\nAdios".into(),
1187 }],
1188 debug_info: None,
1189 })
1190 .unwrap();
1191 prediction_task.await.unwrap();
1192 zeta.read_with(cx, |zeta, cx| {
1193 let prediction = zeta
1194 .current_prediction_for_buffer(&buffer1, &project, cx)
1195 .unwrap();
1196 assert_matches!(
1197 prediction,
1198 BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1199 );
1200 });
1201
1202 let buffer2 = project
1203 .update(cx, |project, cx| {
1204 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1205 project.open_buffer(path, cx)
1206 })
1207 .await
1208 .unwrap();
1209
1210 zeta.read_with(cx, |zeta, cx| {
1211 let prediction = zeta
1212 .current_prediction_for_buffer(&buffer2, &project, cx)
1213 .unwrap();
1214 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1215 });
1216 }
1217
1218 #[gpui::test]
1219 async fn test_simple_request(cx: &mut TestAppContext) {
1220 let (zeta, mut req_rx) = init_test(cx);
1221 let fs = FakeFs::new(cx.executor());
1222 fs.insert_tree(
1223 "/root",
1224 json!({
1225 "foo.md": "Hello!\nHow\nBye"
1226 }),
1227 )
1228 .await;
1229 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1230
1231 let buffer = project
1232 .update(cx, |project, cx| {
1233 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1234 project.open_buffer(path, cx)
1235 })
1236 .await
1237 .unwrap();
1238 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1239 let position = snapshot.anchor_before(language::Point::new(1, 3));
1240
1241 let prediction_task = zeta.update(cx, |zeta, cx| {
1242 zeta.request_prediction(&project, &buffer, position, cx)
1243 });
1244
1245 let (request, respond_tx) = req_rx.next().await.unwrap();
1246 assert_eq!(
1247 request.excerpt_path.as_ref(),
1248 Path::new(path!("root/foo.md"))
1249 );
1250 assert_eq!(
1251 request.cursor_point,
1252 Point {
1253 line: Line(1),
1254 column: 3
1255 }
1256 );
1257
1258 respond_tx
1259 .send(predict_edits_v3::PredictEditsResponse {
1260 request_id: Uuid::new_v4(),
1261 edits: vec![predict_edits_v3::Edit {
1262 path: Path::new(path!("root/foo.md")).into(),
1263 range: Line(0)..Line(snapshot.max_point().row + 1),
1264 content: "Hello!\nHow are you?\nBye".into(),
1265 }],
1266 debug_info: None,
1267 })
1268 .unwrap();
1269
1270 let prediction = prediction_task.await.unwrap().unwrap();
1271
1272 assert_eq!(prediction.edits.len(), 1);
1273 assert_eq!(
1274 prediction.edits[0].0.to_point(&snapshot).start,
1275 language::Point::new(1, 3)
1276 );
1277 assert_eq!(prediction.edits[0].1, " are you?");
1278 }
1279
1280 #[gpui::test]
1281 async fn test_request_events(cx: &mut TestAppContext) {
1282 let (zeta, mut req_rx) = init_test(cx);
1283 let fs = FakeFs::new(cx.executor());
1284 fs.insert_tree(
1285 "/root",
1286 json!({
1287 "foo.md": "Hello!\n\nBye"
1288 }),
1289 )
1290 .await;
1291 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1292
1293 let buffer = project
1294 .update(cx, |project, cx| {
1295 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1296 project.open_buffer(path, cx)
1297 })
1298 .await
1299 .unwrap();
1300
1301 zeta.update(cx, |zeta, cx| {
1302 zeta.register_buffer(&buffer, &project, cx);
1303 });
1304
1305 buffer.update(cx, |buffer, cx| {
1306 buffer.edit(vec![(7..7, "How")], None, cx);
1307 });
1308
1309 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1310 let position = snapshot.anchor_before(language::Point::new(1, 3));
1311
1312 let prediction_task = zeta.update(cx, |zeta, cx| {
1313 zeta.request_prediction(&project, &buffer, position, cx)
1314 });
1315
1316 let (request, respond_tx) = req_rx.next().await.unwrap();
1317
1318 assert_eq!(request.events.len(), 1);
1319 assert_eq!(
1320 request.events[0],
1321 predict_edits_v3::Event::BufferChange {
1322 path: Some(PathBuf::from(path!("root/foo.md"))),
1323 old_path: None,
1324 diff: indoc! {"
1325 @@ -1,3 +1,3 @@
1326 Hello!
1327 -
1328 +How
1329 Bye
1330 "}
1331 .to_string(),
1332 predicted: false
1333 }
1334 );
1335
1336 respond_tx
1337 .send(predict_edits_v3::PredictEditsResponse {
1338 request_id: Uuid::new_v4(),
1339 edits: vec![predict_edits_v3::Edit {
1340 path: Path::new(path!("root/foo.md")).into(),
1341 range: Line(0)..Line(snapshot.max_point().row + 1),
1342 content: "Hello!\nHow are you?\nBye".into(),
1343 }],
1344 debug_info: None,
1345 })
1346 .unwrap();
1347
1348 let prediction = prediction_task.await.unwrap().unwrap();
1349
1350 assert_eq!(prediction.edits.len(), 1);
1351 assert_eq!(
1352 prediction.edits[0].0.to_point(&snapshot).start,
1353 language::Point::new(1, 3)
1354 );
1355 assert_eq!(prediction.edits[0].1, " are you?");
1356 }
1357
1358 #[gpui::test]
1359 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1360 let (zeta, mut req_rx) = init_test(cx);
1361 let fs = FakeFs::new(cx.executor());
1362 fs.insert_tree(
1363 "/root",
1364 json!({
1365 "foo.md": "Hello!\nBye"
1366 }),
1367 )
1368 .await;
1369 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1370
1371 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1372 let diagnostic = lsp::Diagnostic {
1373 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1374 severity: Some(lsp::DiagnosticSeverity::ERROR),
1375 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1376 ..Default::default()
1377 };
1378
1379 project.update(cx, |project, cx| {
1380 project.lsp_store().update(cx, |lsp_store, cx| {
1381 // Create some diagnostics
1382 lsp_store
1383 .update_diagnostics(
1384 LanguageServerId(0),
1385 lsp::PublishDiagnosticsParams {
1386 uri: path_to_buffer_uri.clone(),
1387 diagnostics: vec![diagnostic],
1388 version: None,
1389 },
1390 None,
1391 language::DiagnosticSourceKind::Pushed,
1392 &[],
1393 cx,
1394 )
1395 .unwrap();
1396 });
1397 });
1398
1399 let buffer = project
1400 .update(cx, |project, cx| {
1401 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1402 project.open_buffer(path, cx)
1403 })
1404 .await
1405 .unwrap();
1406
1407 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1408 let position = snapshot.anchor_before(language::Point::new(0, 0));
1409
1410 let _prediction_task = zeta.update(cx, |zeta, cx| {
1411 zeta.request_prediction(&project, &buffer, position, cx)
1412 });
1413
1414 let (request, _respond_tx) = req_rx.next().await.unwrap();
1415
1416 assert_eq!(request.diagnostic_groups.len(), 1);
1417 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1418 .unwrap();
1419 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1420 assert_eq!(
1421 value,
1422 json!({
1423 "entries": [{
1424 "range": {
1425 "start": 8,
1426 "end": 10
1427 },
1428 "diagnostic": {
1429 "source": null,
1430 "code": null,
1431 "code_description": null,
1432 "severity": 1,
1433 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1434 "markdown": null,
1435 "group_id": 0,
1436 "is_primary": true,
1437 "is_disk_based": false,
1438 "is_unnecessary": false,
1439 "source_kind": "Pushed",
1440 "data": null,
1441 "underline": true
1442 }
1443 }],
1444 "primary_ix": 0
1445 })
1446 );
1447 }
1448
1449 fn init_test(
1450 cx: &mut TestAppContext,
1451 ) -> (
1452 Entity<Zeta>,
1453 mpsc::UnboundedReceiver<(
1454 predict_edits_v3::PredictEditsRequest,
1455 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1456 )>,
1457 ) {
1458 cx.update(move |cx| {
1459 let settings_store = SettingsStore::test(cx);
1460 cx.set_global(settings_store);
1461 language::init(cx);
1462 Project::init_settings(cx);
1463
1464 let (req_tx, req_rx) = mpsc::unbounded();
1465
1466 let http_client = FakeHttpClient::create({
1467 move |req| {
1468 let uri = req.uri().path().to_string();
1469 let mut body = req.into_body();
1470 let req_tx = req_tx.clone();
1471 async move {
1472 let resp = match uri.as_str() {
1473 "/client/llm_tokens" => serde_json::to_string(&json!({
1474 "token": "test"
1475 }))
1476 .unwrap(),
1477 "/predict_edits/v3" => {
1478 let mut buf = Vec::new();
1479 body.read_to_end(&mut buf).await.ok();
1480 let req = serde_json::from_slice(&buf).unwrap();
1481
1482 let (res_tx, res_rx) = oneshot::channel();
1483 req_tx.unbounded_send((req, res_tx)).unwrap();
1484 serde_json::to_string(&res_rx.await?).unwrap()
1485 }
1486 _ => {
1487 panic!("Unexpected path: {}", uri)
1488 }
1489 };
1490
1491 Ok(Response::builder().body(resp.into()).unwrap())
1492 }
1493 }
1494 });
1495
1496 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1497 client.cloud_client().set_credentials(1, "test".into());
1498
1499 language_model::init(client.clone(), cx);
1500
1501 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1502 let zeta = Zeta::global(&client, &user_store, cx);
1503
1504 (zeta, req_rx)
1505 })
1506 }
1507}