1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
10use edit_prediction_context::{
11 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
12 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
13};
14use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
15use futures::AsyncReadExt as _;
16use futures::channel::{mpsc, oneshot};
17use gpui::http_client::{AsyncBody, Method};
18use gpui::{
19 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
20 http_client, prelude::*,
21};
22use language::BufferSnapshot;
23use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
24use language_model::{LlmApiToken, RefreshLlmTokenListener};
25use project::Project;
26use release_channel::AppVersion;
27use serde::de::DeserializeOwned;
28use std::collections::{HashMap, VecDeque, hash_map};
29use std::path::Path;
30use std::str::FromStr as _;
31use std::sync::Arc;
32use std::time::{Duration, Instant};
33use thiserror::Error;
34use util::rel_path::RelPathBuf;
35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
36
37mod prediction;
38mod provider;
39
40use crate::prediction::EditPrediction;
41pub use provider::ZetaEditPredictionProvider;
42
/// Consecutive edits to the same buffer that arrive within this window of the previous
/// recorded change are merged into a single history event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1000);
44
/// Upper bound on the number of buffer-change events kept per project.
///
/// When the history is full, the oldest half is dropped at once rather than one event at
/// a time, so the retained prefix of the history stays stable across requests.
const MAX_EVENT_COUNT: usize = 16;
47
48pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
49 use_imports: true,
50 max_retrieved_declarations: 0,
51 excerpt: EditPredictionExcerptOptions {
52 max_bytes: 512,
53 min_bytes: 128,
54 target_before_cursor_over_total_bytes: 0.5,
55 },
56 score: EditPredictionScoreOptions {
57 omit_excerpt_overlaps: true,
58 },
59};
60
61pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
62 context: DEFAULT_CONTEXT_OPTIONS,
63 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
64 max_diagnostic_bytes: 2048,
65 prompt_format: PromptFormat::DEFAULT,
66 file_indexing_parallelism: 1,
67};
68
69pub struct Zeta2FeatureFlag;
70
71impl FeatureFlag for Zeta2FeatureFlag {
72 const NAME: &'static str = "zeta2";
73
74 fn enabled_for_staff() -> bool {
75 false
76 }
77}
78
79#[derive(Clone)]
80struct ZetaGlobal(Entity<Zeta>);
81
82impl Global for ZetaGlobal {}
83
84pub struct Zeta {
85 client: Arc<Client>,
86 user_store: Entity<UserStore>,
87 llm_token: LlmApiToken,
88 _llm_token_subscription: Subscription,
89 projects: HashMap<EntityId, ZetaProject>,
90 options: ZetaOptions,
91 update_required: bool,
92 debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
93}
94
95#[derive(Debug, Clone, PartialEq)]
96pub struct ZetaOptions {
97 pub context: EditPredictionContextOptions,
98 pub max_prompt_bytes: usize,
99 pub max_diagnostic_bytes: usize,
100 pub prompt_format: predict_edits_v3::PromptFormat,
101 pub file_indexing_parallelism: usize,
102}
103
104pub struct PredictionDebugInfo {
105 pub request: predict_edits_v3::PredictEditsRequest,
106 pub retrieval_time: TimeDelta,
107 pub buffer: WeakEntity<Buffer>,
108 pub position: language::Anchor,
109 pub local_prompt: Result<String, String>,
110 pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
111}
112
113pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
114
115struct ZetaProject {
116 syntax_index: Entity<SyntaxIndex>,
117 events: VecDeque<Event>,
118 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
119 current_prediction: Option<CurrentEditPrediction>,
120}
121
122#[derive(Debug, Clone)]
123struct CurrentEditPrediction {
124 pub requested_by_buffer_id: EntityId,
125 pub prediction: EditPrediction,
126}
127
128impl CurrentEditPrediction {
129 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
130 let Some(new_edits) = self
131 .prediction
132 .interpolate(&self.prediction.buffer.read(cx))
133 else {
134 return false;
135 };
136
137 if self.prediction.buffer != old_prediction.prediction.buffer {
138 return true;
139 }
140
141 let Some(old_edits) = old_prediction
142 .prediction
143 .interpolate(&old_prediction.prediction.buffer.read(cx))
144 else {
145 return true;
146 };
147
148 // This reduces the occurrence of UI thrash from replacing edits
149 //
150 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
151 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
152 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
153 && old_edits.len() == 1
154 && new_edits.len() == 1
155 {
156 let (old_range, old_text) = &old_edits[0];
157 let (new_range, new_text) = &new_edits[0];
158 new_range == old_range && new_text.starts_with(old_text)
159 } else {
160 true
161 }
162 }
163}
164
165/// A prediction from the perspective of a buffer.
166#[derive(Debug)]
167enum BufferEditPrediction<'a> {
168 Local { prediction: &'a EditPrediction },
169 Jump { prediction: &'a EditPrediction },
170}
171
172struct RegisteredBuffer {
173 snapshot: BufferSnapshot,
174 _subscriptions: [gpui::Subscription; 2],
175}
176
177#[derive(Clone)]
178pub enum Event {
179 BufferChange {
180 old_snapshot: BufferSnapshot,
181 new_snapshot: BufferSnapshot,
182 timestamp: Instant,
183 },
184}
185
186impl Zeta {
187 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
188 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
189 }
190
191 pub fn global(
192 client: &Arc<Client>,
193 user_store: &Entity<UserStore>,
194 cx: &mut App,
195 ) -> Entity<Self> {
196 cx.try_global::<ZetaGlobal>()
197 .map(|global| global.0.clone())
198 .unwrap_or_else(|| {
199 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
200 cx.set_global(ZetaGlobal(zeta.clone()));
201 zeta
202 })
203 }
204
205 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
206 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
207
208 Self {
209 projects: HashMap::new(),
210 client,
211 user_store,
212 options: DEFAULT_OPTIONS,
213 llm_token: LlmApiToken::default(),
214 _llm_token_subscription: cx.subscribe(
215 &refresh_llm_token_listener,
216 |this, _listener, _event, cx| {
217 let client = this.client.clone();
218 let llm_token = this.llm_token.clone();
219 cx.spawn(async move |_this, _cx| {
220 llm_token.refresh(&client).await?;
221 anyhow::Ok(())
222 })
223 .detach_and_log_err(cx);
224 },
225 ),
226 update_required: false,
227 debug_tx: None,
228 }
229 }
230
231 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
232 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
233 self.debug_tx = Some(debug_watch_tx);
234 debug_watch_rx
235 }
236
237 pub fn options(&self) -> &ZetaOptions {
238 &self.options
239 }
240
241 pub fn set_options(&mut self, options: ZetaOptions) {
242 self.options = options;
243 }
244
245 pub fn clear_history(&mut self) {
246 for zeta_project in self.projects.values_mut() {
247 zeta_project.events.clear();
248 }
249 }
250
251 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
252 self.user_store.read(cx).edit_prediction_usage()
253 }
254
255 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
256 self.get_or_init_zeta_project(project, cx);
257 }
258
259 pub fn register_buffer(
260 &mut self,
261 buffer: &Entity<Buffer>,
262 project: &Entity<Project>,
263 cx: &mut Context<Self>,
264 ) {
265 let zeta_project = self.get_or_init_zeta_project(project, cx);
266 Self::register_buffer_impl(zeta_project, buffer, project, cx);
267 }
268
269 fn get_or_init_zeta_project(
270 &mut self,
271 project: &Entity<Project>,
272 cx: &mut App,
273 ) -> &mut ZetaProject {
274 self.projects
275 .entry(project.entity_id())
276 .or_insert_with(|| ZetaProject {
277 syntax_index: cx.new(|cx| {
278 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
279 }),
280 events: VecDeque::new(),
281 registered_buffers: HashMap::new(),
282 current_prediction: None,
283 })
284 }
285
286 fn register_buffer_impl<'a>(
287 zeta_project: &'a mut ZetaProject,
288 buffer: &Entity<Buffer>,
289 project: &Entity<Project>,
290 cx: &mut Context<Self>,
291 ) -> &'a mut RegisteredBuffer {
292 let buffer_id = buffer.entity_id();
293 match zeta_project.registered_buffers.entry(buffer_id) {
294 hash_map::Entry::Occupied(entry) => entry.into_mut(),
295 hash_map::Entry::Vacant(entry) => {
296 let snapshot = buffer.read(cx).snapshot();
297 let project_entity_id = project.entity_id();
298 entry.insert(RegisteredBuffer {
299 snapshot,
300 _subscriptions: [
301 cx.subscribe(buffer, {
302 let project = project.downgrade();
303 move |this, buffer, event, cx| {
304 if let language::BufferEvent::Edited = event
305 && let Some(project) = project.upgrade()
306 {
307 this.report_changes_for_buffer(&buffer, &project, cx);
308 }
309 }
310 }),
311 cx.observe_release(buffer, move |this, _buffer, _cx| {
312 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
313 else {
314 return;
315 };
316 zeta_project.registered_buffers.remove(&buffer_id);
317 }),
318 ],
319 })
320 }
321 }
322 }
323
324 fn report_changes_for_buffer(
325 &mut self,
326 buffer: &Entity<Buffer>,
327 project: &Entity<Project>,
328 cx: &mut Context<Self>,
329 ) -> BufferSnapshot {
330 let zeta_project = self.get_or_init_zeta_project(project, cx);
331 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
332
333 let new_snapshot = buffer.read(cx).snapshot();
334 if new_snapshot.version != registered_buffer.snapshot.version {
335 let old_snapshot =
336 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
337 Self::push_event(
338 zeta_project,
339 Event::BufferChange {
340 old_snapshot,
341 new_snapshot: new_snapshot.clone(),
342 timestamp: Instant::now(),
343 },
344 );
345 }
346
347 new_snapshot
348 }
349
350 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
351 let events = &mut zeta_project.events;
352
353 if let Some(Event::BufferChange {
354 new_snapshot: last_new_snapshot,
355 timestamp: last_timestamp,
356 ..
357 }) = events.back_mut()
358 {
359 // Coalesce edits for the same buffer when they happen one after the other.
360 let Event::BufferChange {
361 old_snapshot,
362 new_snapshot,
363 timestamp,
364 } = &event;
365
366 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
367 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
368 && old_snapshot.version == last_new_snapshot.version
369 {
370 *last_new_snapshot = new_snapshot.clone();
371 *last_timestamp = *timestamp;
372 return;
373 }
374 }
375
376 if events.len() >= MAX_EVENT_COUNT {
377 // These are halved instead of popping to improve prompt caching.
378 events.drain(..MAX_EVENT_COUNT / 2);
379 }
380
381 events.push_back(event);
382 }
383
384 fn current_prediction_for_buffer(
385 &self,
386 buffer: &Entity<Buffer>,
387 project: &Entity<Project>,
388 cx: &App,
389 ) -> Option<BufferEditPrediction<'_>> {
390 let project_state = self.projects.get(&project.entity_id())?;
391
392 let CurrentEditPrediction {
393 requested_by_buffer_id,
394 prediction,
395 } = project_state.current_prediction.as_ref()?;
396
397 if prediction.targets_buffer(buffer.read(cx), cx) {
398 Some(BufferEditPrediction::Local { prediction })
399 } else if *requested_by_buffer_id == buffer.entity_id() {
400 Some(BufferEditPrediction::Jump { prediction })
401 } else {
402 None
403 }
404 }
405
406 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
407 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
408 return;
409 };
410
411 let Some(prediction) = project_state.current_prediction.take() else {
412 return;
413 };
414 let request_id = prediction.prediction.id.into();
415
416 let client = self.client.clone();
417 let llm_token = self.llm_token.clone();
418 let app_version = AppVersion::global(cx);
419 cx.spawn(async move |this, cx| {
420 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
421 http_client::Url::parse(&predict_edits_url)?
422 } else {
423 client
424 .http_client()
425 .build_zed_llm_url("/predict_edits/accept", &[])?
426 };
427
428 let response = cx
429 .background_spawn(Self::send_api_request::<()>(
430 move |builder| {
431 let req = builder.uri(url.as_ref()).body(
432 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
433 );
434 Ok(req?)
435 },
436 client,
437 llm_token,
438 app_version,
439 ))
440 .await;
441
442 Self::handle_api_response(&this, response, cx)?;
443 anyhow::Ok(())
444 })
445 .detach_and_log_err(cx);
446 }
447
448 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
449 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
450 project_state.current_prediction.take();
451 };
452 }
453
454 pub fn refresh_prediction(
455 &mut self,
456 project: &Entity<Project>,
457 buffer: &Entity<Buffer>,
458 position: language::Anchor,
459 cx: &mut Context<Self>,
460 ) -> Task<Result<()>> {
461 let request_task = self.request_prediction(project, buffer, position, cx);
462 let buffer = buffer.clone();
463 let project = project.clone();
464
465 cx.spawn(async move |this, cx| {
466 if let Some(prediction) = request_task.await? {
467 this.update(cx, |this, cx| {
468 let project_state = this
469 .projects
470 .get_mut(&project.entity_id())
471 .context("Project not found")?;
472
473 let new_prediction = CurrentEditPrediction {
474 requested_by_buffer_id: buffer.entity_id(),
475 prediction: prediction,
476 };
477
478 if project_state
479 .current_prediction
480 .as_ref()
481 .is_none_or(|old_prediction| {
482 new_prediction.should_replace_prediction(&old_prediction, cx)
483 })
484 {
485 project_state.current_prediction = Some(new_prediction);
486 }
487 anyhow::Ok(())
488 })??;
489 }
490 Ok(())
491 })
492 }
493
494 fn request_prediction(
495 &mut self,
496 project: &Entity<Project>,
497 buffer: &Entity<Buffer>,
498 position: language::Anchor,
499 cx: &mut Context<Self>,
500 ) -> Task<Result<Option<EditPrediction>>> {
501 let project_state = self.projects.get(&project.entity_id());
502
503 let index_state = project_state.map(|state| {
504 state
505 .syntax_index
506 .read_with(cx, |index, _cx| index.state().clone())
507 });
508 let options = self.options.clone();
509 let snapshot = buffer.read(cx).snapshot();
510 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
511 return Task::ready(Err(anyhow!("No file path for excerpt")));
512 };
513 let client = self.client.clone();
514 let llm_token = self.llm_token.clone();
515 let app_version = AppVersion::global(cx);
516 let worktree_snapshots = project
517 .read(cx)
518 .worktrees(cx)
519 .map(|worktree| worktree.read(cx).snapshot())
520 .collect::<Vec<_>>();
521 let debug_tx = self.debug_tx.clone();
522
523 let events = project_state
524 .map(|state| {
525 state
526 .events
527 .iter()
528 .filter_map(|event| match event {
529 Event::BufferChange {
530 old_snapshot,
531 new_snapshot,
532 ..
533 } => {
534 let path = new_snapshot.file().map(|f| f.full_path(cx));
535
536 let old_path = old_snapshot.file().and_then(|f| {
537 let old_path = f.full_path(cx);
538 if Some(&old_path) != path.as_ref() {
539 Some(old_path)
540 } else {
541 None
542 }
543 });
544
545 // TODO [zeta2] move to bg?
546 let diff =
547 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
548
549 if path == old_path && diff.is_empty() {
550 None
551 } else {
552 Some(predict_edits_v3::Event::BufferChange {
553 old_path,
554 path,
555 diff,
556 //todo: Actually detect if this edit was predicted or not
557 predicted: false,
558 })
559 }
560 }
561 })
562 .collect::<Vec<_>>()
563 })
564 .unwrap_or_default();
565
566 let diagnostics = snapshot.diagnostic_sets().clone();
567
568 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
569 let mut path = f.worktree.read(cx).absolutize(&f.path);
570 if path.pop() { Some(path) } else { None }
571 });
572
573 // TODO data collection
574 let can_collect_data = cx.is_staff();
575
576 let request_task = cx.background_spawn({
577 let snapshot = snapshot.clone();
578 let buffer = buffer.clone();
579 async move {
580 let index_state = if let Some(index_state) = index_state {
581 Some(index_state.lock_owned().await)
582 } else {
583 None
584 };
585
586 let cursor_offset = position.to_offset(&snapshot);
587 let cursor_point = cursor_offset.to_point(&snapshot);
588
589 let before_retrieval = chrono::Utc::now();
590
591 let Some(context) = EditPredictionContext::gather_context(
592 cursor_point,
593 &snapshot,
594 parent_abs_path.as_deref(),
595 &options.context,
596 index_state.as_deref(),
597 ) else {
598 return Ok((None, None));
599 };
600
601 let retrieval_time = chrono::Utc::now() - before_retrieval;
602
603 let (diagnostic_groups, diagnostic_groups_truncated) =
604 Self::gather_nearby_diagnostics(
605 cursor_offset,
606 &diagnostics,
607 &snapshot,
608 options.max_diagnostic_bytes,
609 );
610
611 let request = make_cloud_request(
612 excerpt_path,
613 context,
614 events,
615 can_collect_data,
616 diagnostic_groups,
617 diagnostic_groups_truncated,
618 None,
619 debug_tx.is_some(),
620 &worktree_snapshots,
621 index_state.as_deref(),
622 Some(options.max_prompt_bytes),
623 options.prompt_format,
624 );
625
626 let debug_response_tx = if let Some(debug_tx) = &debug_tx {
627 let (response_tx, response_rx) = oneshot::channel();
628
629 let local_prompt = PlannedPrompt::populate(&request)
630 .and_then(|p| p.to_prompt_string().map(|p| p.0))
631 .map_err(|err| err.to_string());
632
633 debug_tx
634 .unbounded_send(PredictionDebugInfo {
635 request: request.clone(),
636 retrieval_time,
637 buffer: buffer.downgrade(),
638 local_prompt,
639 position,
640 response_rx,
641 })
642 .ok();
643 Some(response_tx)
644 } else {
645 None
646 };
647
648 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
649 if let Some(debug_response_tx) = debug_response_tx {
650 debug_response_tx
651 .send(Err("Request skipped".to_string()))
652 .ok();
653 }
654 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
655 }
656
657 let response =
658 Self::send_prediction_request(client, llm_token, app_version, request).await;
659
660 if let Some(debug_response_tx) = debug_response_tx {
661 debug_response_tx
662 .send(
663 response
664 .as_ref()
665 .map_err(|err| err.to_string())
666 .map(|response| response.0.clone()),
667 )
668 .ok();
669 }
670
671 response.map(|(res, usage)| (Some(res), usage))
672 }
673 });
674
675 let buffer = buffer.clone();
676
677 cx.spawn({
678 let project = project.clone();
679 async move |this, cx| {
680 let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
681 else {
682 return Ok(None);
683 };
684
685 // TODO telemetry: duration, etc
686 Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
687 }
688 })
689 }
690
691 async fn send_prediction_request(
692 client: Arc<Client>,
693 llm_token: LlmApiToken,
694 app_version: SemanticVersion,
695 request: predict_edits_v3::PredictEditsRequest,
696 ) -> Result<(
697 predict_edits_v3::PredictEditsResponse,
698 Option<EditPredictionUsage>,
699 )> {
700 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
701 http_client::Url::parse(&predict_edits_url)?
702 } else {
703 client
704 .http_client()
705 .build_zed_llm_url("/predict_edits/v3", &[])?
706 };
707
708 Self::send_api_request(
709 |builder| {
710 let req = builder
711 .uri(url.as_ref())
712 .body(serde_json::to_string(&request)?.into());
713 Ok(req?)
714 },
715 client,
716 llm_token,
717 app_version,
718 )
719 .await
720 }
721
722 fn handle_api_response<T>(
723 this: &WeakEntity<Self>,
724 response: Result<(T, Option<EditPredictionUsage>)>,
725 cx: &mut gpui::AsyncApp,
726 ) -> Result<T> {
727 match response {
728 Ok((data, usage)) => {
729 if let Some(usage) = usage {
730 this.update(cx, |this, cx| {
731 this.user_store.update(cx, |user_store, cx| {
732 user_store.update_edit_prediction_usage(usage, cx);
733 });
734 })
735 .ok();
736 }
737 Ok(data)
738 }
739 Err(err) => {
740 if err.is::<ZedUpdateRequiredError>() {
741 cx.update(|cx| {
742 this.update(cx, |this, _cx| {
743 this.update_required = true;
744 })
745 .ok();
746
747 let error_message: SharedString = err.to_string().into();
748 show_app_notification(
749 NotificationId::unique::<ZedUpdateRequiredError>(),
750 cx,
751 move |cx| {
752 cx.new(|cx| {
753 ErrorMessagePrompt::new(error_message.clone(), cx)
754 .with_link_button("Update Zed", "https://zed.dev/releases")
755 })
756 },
757 );
758 })
759 .ok();
760 }
761 Err(err)
762 }
763 }
764 }
765
766 async fn send_api_request<Res>(
767 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
768 client: Arc<Client>,
769 llm_token: LlmApiToken,
770 app_version: SemanticVersion,
771 ) -> Result<(Res, Option<EditPredictionUsage>)>
772 where
773 Res: DeserializeOwned,
774 {
775 let http_client = client.http_client();
776 let mut token = llm_token.acquire(&client).await?;
777 let mut did_retry = false;
778
779 loop {
780 let request_builder = http_client::Request::builder().method(Method::POST);
781
782 let request = build(
783 request_builder
784 .header("Content-Type", "application/json")
785 .header("Authorization", format!("Bearer {}", token))
786 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
787 )?;
788
789 let mut response = http_client.send(request).await?;
790
791 if let Some(minimum_required_version) = response
792 .headers()
793 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
794 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
795 {
796 anyhow::ensure!(
797 app_version >= minimum_required_version,
798 ZedUpdateRequiredError {
799 minimum_version: minimum_required_version
800 }
801 );
802 }
803
804 if response.status().is_success() {
805 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
806
807 let mut body = Vec::new();
808 response.body_mut().read_to_end(&mut body).await?;
809 return Ok((serde_json::from_slice(&body)?, usage));
810 } else if !did_retry
811 && response
812 .headers()
813 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
814 .is_some()
815 {
816 did_retry = true;
817 token = llm_token.refresh(&client).await?;
818 } else {
819 let mut body = String::new();
820 response.body_mut().read_to_string(&mut body).await?;
821 anyhow::bail!(
822 "Request failed with status: {:?}\nBody: {}",
823 response.status(),
824 body
825 );
826 }
827 }
828 }
829
830 fn gather_nearby_diagnostics(
831 cursor_offset: usize,
832 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
833 snapshot: &BufferSnapshot,
834 max_diagnostics_bytes: usize,
835 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
836 // TODO: Could make this more efficient
837 let mut diagnostic_groups = Vec::new();
838 for (language_server_id, diagnostics) in diagnostic_sets {
839 let mut groups = Vec::new();
840 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
841 diagnostic_groups.extend(
842 groups
843 .into_iter()
844 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
845 );
846 }
847
848 // sort by proximity to cursor
849 diagnostic_groups.sort_by_key(|group| {
850 let range = &group.entries[group.primary_ix].range;
851 if range.start >= cursor_offset {
852 range.start - cursor_offset
853 } else if cursor_offset >= range.end {
854 cursor_offset - range.end
855 } else {
856 (cursor_offset - range.start).min(range.end - cursor_offset)
857 }
858 });
859
860 let mut results = Vec::new();
861 let mut diagnostic_groups_truncated = false;
862 let mut diagnostics_byte_count = 0;
863 for group in diagnostic_groups {
864 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
865 diagnostics_byte_count += raw_value.get().len();
866 if diagnostics_byte_count > max_diagnostics_bytes {
867 diagnostic_groups_truncated = true;
868 break;
869 }
870 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
871 }
872
873 (results, diagnostic_groups_truncated)
874 }
875
876 // TODO: Dedupe with similar code in request_prediction?
877 pub fn cloud_request_for_zeta_cli(
878 &mut self,
879 project: &Entity<Project>,
880 buffer: &Entity<Buffer>,
881 position: language::Anchor,
882 cx: &mut Context<Self>,
883 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
884 let project_state = self.projects.get(&project.entity_id());
885
886 let index_state = project_state.map(|state| {
887 state
888 .syntax_index
889 .read_with(cx, |index, _cx| index.state().clone())
890 });
891 let options = self.options.clone();
892 let snapshot = buffer.read(cx).snapshot();
893 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
894 return Task::ready(Err(anyhow!("No file path for excerpt")));
895 };
896 let worktree_snapshots = project
897 .read(cx)
898 .worktrees(cx)
899 .map(|worktree| worktree.read(cx).snapshot())
900 .collect::<Vec<_>>();
901
902 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
903 let mut path = f.worktree.read(cx).absolutize(&f.path);
904 if path.pop() { Some(path) } else { None }
905 });
906
907 cx.background_spawn(async move {
908 let index_state = if let Some(index_state) = index_state {
909 Some(index_state.lock_owned().await)
910 } else {
911 None
912 };
913
914 let cursor_point = position.to_point(&snapshot);
915
916 let debug_info = true;
917 EditPredictionContext::gather_context(
918 cursor_point,
919 &snapshot,
920 parent_abs_path.as_deref(),
921 &options.context,
922 index_state.as_deref(),
923 )
924 .context("Failed to select excerpt")
925 .map(|context| {
926 make_cloud_request(
927 excerpt_path.into(),
928 context,
929 // TODO pass everything
930 Vec::new(),
931 false,
932 Vec::new(),
933 false,
934 None,
935 debug_info,
936 &worktree_snapshots,
937 index_state.as_deref(),
938 Some(options.max_prompt_bytes),
939 options.prompt_format,
940 )
941 })
942 })
943 }
944
945 pub fn wait_for_initial_indexing(
946 &mut self,
947 project: &Entity<Project>,
948 cx: &mut App,
949 ) -> Task<Result<()>> {
950 let zeta_project = self.get_or_init_zeta_project(project, cx);
951 zeta_project
952 .syntax_index
953 .read(cx)
954 .wait_for_initial_file_indexing(cx)
955 }
956}
957
958#[derive(Error, Debug)]
959#[error(
960 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
961)]
962pub struct ZedUpdateRequiredError {
963 minimum_version: SemanticVersion,
964}
965
966fn make_cloud_request(
967 excerpt_path: Arc<Path>,
968 context: EditPredictionContext,
969 events: Vec<predict_edits_v3::Event>,
970 can_collect_data: bool,
971 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
972 diagnostic_groups_truncated: bool,
973 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
974 debug_info: bool,
975 worktrees: &Vec<worktree::Snapshot>,
976 index_state: Option<&SyntaxIndexState>,
977 prompt_max_bytes: Option<usize>,
978 prompt_format: PromptFormat,
979) -> predict_edits_v3::PredictEditsRequest {
980 let mut signatures = Vec::new();
981 let mut declaration_to_signature_index = HashMap::default();
982 let mut referenced_declarations = Vec::new();
983
984 for snippet in context.declarations {
985 let project_entry_id = snippet.declaration.project_entry_id();
986 let Some(path) = worktrees.iter().find_map(|worktree| {
987 worktree.entry_for_id(project_entry_id).map(|entry| {
988 let mut full_path = RelPathBuf::new();
989 full_path.push(worktree.root_name());
990 full_path.push(&entry.path);
991 full_path
992 })
993 }) else {
994 continue;
995 };
996
997 let parent_index = index_state.and_then(|index_state| {
998 snippet.declaration.parent().and_then(|parent| {
999 add_signature(
1000 parent,
1001 &mut declaration_to_signature_index,
1002 &mut signatures,
1003 index_state,
1004 )
1005 })
1006 });
1007
1008 let (text, text_is_truncated) = snippet.declaration.item_text();
1009 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1010 path: path.as_std_path().into(),
1011 text: text.into(),
1012 range: snippet.declaration.item_line_range(),
1013 text_is_truncated,
1014 signature_range: snippet.declaration.signature_range_in_item_text(),
1015 parent_index,
1016 signature_score: snippet.score(DeclarationStyle::Signature),
1017 declaration_score: snippet.score(DeclarationStyle::Declaration),
1018 score_components: snippet.components,
1019 });
1020 }
1021
1022 let excerpt_parent = index_state.and_then(|index_state| {
1023 context
1024 .excerpt
1025 .parent_declarations
1026 .last()
1027 .and_then(|(parent, _)| {
1028 add_signature(
1029 *parent,
1030 &mut declaration_to_signature_index,
1031 &mut signatures,
1032 index_state,
1033 )
1034 })
1035 });
1036
1037 predict_edits_v3::PredictEditsRequest {
1038 excerpt_path,
1039 excerpt: context.excerpt_text.body,
1040 excerpt_line_range: context.excerpt.line_range,
1041 excerpt_range: context.excerpt.range,
1042 cursor_point: predict_edits_v3::Point {
1043 line: predict_edits_v3::Line(context.cursor_point.row),
1044 column: context.cursor_point.column,
1045 },
1046 referenced_declarations,
1047 signatures,
1048 excerpt_parent,
1049 events,
1050 can_collect_data,
1051 diagnostic_groups,
1052 diagnostic_groups_truncated,
1053 git_info,
1054 debug_info,
1055 prompt_max_bytes,
1056 prompt_format,
1057 }
1058}
1059
1060fn add_signature(
1061 declaration_id: DeclarationId,
1062 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1063 signatures: &mut Vec<Signature>,
1064 index: &SyntaxIndexState,
1065) -> Option<usize> {
1066 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1067 return Some(*signature_index);
1068 }
1069 let Some(parent_declaration) = index.declaration(declaration_id) else {
1070 log::error!("bug: missing parent declaration");
1071 return None;
1072 };
1073 let parent_index = parent_declaration.parent().and_then(|parent| {
1074 add_signature(parent, declaration_to_signature_index, signatures, index)
1075 });
1076 let (text, text_is_truncated) = parent_declaration.signature_text();
1077 let signature_index = signatures.len();
1078 signatures.push(Signature {
1079 text: text.into(),
1080 text_is_truncated,
1081 parent_index,
1082 range: parent_declaration.signature_line_range(),
1083 });
1084 declaration_to_signature_index.insert(declaration_id, signature_index);
1085 Some(signature_index)
1086}
1087
#[cfg(test)]
mod tests {
    use std::{
        ops::Range,
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{Buffer, LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // `current_prediction_for_buffer` reports a `Local` prediction when the
    // predicted edit targets the queried buffer, and a `Jump` when it targets a
    // different file. Once that other file's buffer is opened, the same
    // prediction is reported as `Local` for it.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = open_buffer(&project, path!("root/1.txt"), cx).await;
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(single_edit_response(
                path!("root/1.txt"),
                Line(0)..Line(snapshot1.max_point().row + 1),
                "Hello!\nHow are you?\nBye",
            ))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // `2.txt` is not open yet, so its line count (three lines) is spelled
        // out rather than borrowed from `1.txt`'s snapshot.
        respond_tx
            .send(single_edit_response(
                path!("root/2.txt"),
                Line(0)..Line(3),
                "Hola!\nComo estas?\nAdios",
            ))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = open_buffer(&project, path!("root/2.txt"), cx).await;

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // A single request carries the excerpt path and cursor point. The
    // whole-file replacement in the response is reduced to the minimal
    // inserted text at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = open_buffer(&project, path!("root/foo.md"), cx).await;
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(single_edit_response(
                path!("root/foo.md"),
                Line(0)..Line(snapshot.max_point().row + 1),
                "Hello!\nHow are you?\nBye",
            ))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // An edit made to a registered buffer is sent with the next request as a
    // `BufferChange` event carrying a unified diff. The test expects
    // `predicted: false` for this manual edit.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = open_buffer(&project, path!("root/foo.md"), cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(single_edit_response(
                path!("root/foo.md"),
                Line(0)..Line(snapshot.max_point().row + 1),
                "Hello!\nHow are you?\nBye",
            ))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Diagnostics published by a language server for the file are serialized
    // into the request's `diagnostic_groups`.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // Line 1 is "Bye" (3 characters), so the end column 5 lies past the end
        // of the line. The expected serialized range below ends at offset 10,
        // which is the end of that line.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = open_buffer(&project, path!("root/foo.md"), cx).await;

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Opens the project buffer for `path`, given relative to the worktree
    /// (e.g. `root/foo.md`). Panics if the path is not in the project or the
    /// buffer fails to open.
    async fn open_buffer(
        project: &Entity<Project>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path, cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap()
    }

    /// Builds a server response with a fresh request id and a single edit that
    /// replaces the line range `range` of the file at `path` with `content`.
    fn single_edit_response(
        path: &str,
        range: Range<Line>,
        content: &str,
    ) -> predict_edits_v3::PredictEditsResponse {
        predict_edits_v3::PredictEditsResponse {
            request_id: Uuid::new_v4(),
            edits: vec![predict_edits_v3::Edit {
                path: Path::new(path).into(),
                range,
                content: content.into(),
            }],
            debug_info: None,
        }
    }

    /// Sets up test globals (settings, language, project settings, language
    /// model) and returns a global [`Zeta`] backed by a fake HTTP client.
    ///
    /// The fake client answers `/client/llm_tokens` with a test token. It
    /// decodes each `/predict_edits/v3` request body and forwards it, together
    /// with a oneshot sender, on the returned receiver. The HTTP response is
    /// whatever the test sends back through that sender. Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                // If the test drops the sender without replying,
                                // `?` returns the cancellation as an error from
                                // this fake request.
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}