use anyhow::{Context as _, Result, anyhow, bail};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};

use std::env;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod merge_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
pub mod udiff;
mod xml_edits;

use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

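// Local development overrides: `ZED_ZETA2_OLLAMA` routes requests to a local
// Ollama server, `ZED_ZETA2_MODEL` overrides the model ID, and
// `ZED_PREDICT_EDITS_URL` overrides the prediction endpoint directly.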
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "llm-response-cache")]
    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
}

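/// Cache for raw LLM responses, keyed by a hash of the request URL and body,
/// letting repeated identical requests (e.g. in tests or evals) skip the
/// network round trip.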
#[cfg(feature = "llm-response-cache")]
pub trait LlmResponseCache: Send + Sync {
    fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
    fn read_response(&self, key: u64) -> Option<String>;
    fn write_response(&self, key: u64, value: &str);
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
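    /// Returns whether `self` (the newly requested prediction) should replace
    /// `old_prediction`. A prediction that can no longer be interpolated
    /// against the current buffer contents never replaces the old one.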
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits.
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}

impl Event {
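    /// Converts this event to the wire format, rendering the buffer change as
    /// a unified diff. Returns `None` when there is nothing to report.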
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "llm-response-cache")]
            llm_response_cache: None,
        }
    }

    #[cfg(feature = "llm-response-cache")]
    pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
        self.llm_response_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

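    /// Appends a change event to the project's history, coalescing it with the
    /// previous event when it continues an edit to the same buffer within the
    /// grouping interval.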
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

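    /// Resolves the current prediction from `buffer`'s perspective: `Local` if
    /// the prediction targets this buffer, `Jump` if it was requested from
    /// this buffer but targets a different one.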
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
        };
    }

    pub fn refresh_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let request_task = self.request_prediction(project, buffer, position, cx);
        let buffer = buffer.clone();
        let project = project.clone();

        cx.spawn(async move |this, cx| {
            if let Some(prediction) = request_task.await? {
                this.update(cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project.entity_id())
                        .context("Project not found")?;

                    let new_prediction = CurrentEditPrediction {
                        requested_by_buffer_id: buffer.entity_id(),
                        prediction,
                    };

                    if project_state
                        .current_prediction
                        .as_ref()
                        .is_none_or(|old_prediction| {
                            new_prediction.should_replace_prediction(old_prediction, cx)
                        })
                    {
                        project_state.current_prediction = Some(new_prediction);
                    }
                    anyhow::Ok(())
                })??;
            }
            Ok(())
        })
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        #[cfg(feature = "llm-response-cache")]
        let llm_response_cache = self.llm_response_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "llm-response-cache")]
                    llm_response_cache,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
            Arc<dyn LlmResponseCache>,
        >,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "llm-response-cache")]
        let cache_key = if let Some(cache) = llm_response_cache {
            let request_json = serde_json::to_string(&request)?;
            let key = cache.get_key(&url, &request_json);

            if let Some(response_str) = cache.read_response(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "llm-response-cache")]
        if let Some((cache, key)) = cache_key {
            cache.write_response(key, &serde_json::to_string(&response)?);
        }

        Ok((response, usage))
    }

    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

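    /// Sends an authenticated request to the LLM service: attaches the LLM
    /// token and Zed version headers, surfaces the server's minimum required
    /// version as [`ZedUpdateRequiredError`], and retries once after
    /// refreshing an expired token.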
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }

    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "llm-response-cache")]
        let llm_response_cache = self.llm_response_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "llm-response-cache")]
                llm_response_cache,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

    pub fn set_context(
        &mut self,
        project: Entity<Project>,
        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
    ) {
        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
            zeta_project.context = Some(context);
        }
    }

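    /// Collects diagnostic groups sorted by proximity to the cursor, stopping
    /// once their serialized size exceeds `max_diagnostics_bytes`. The boolean
    /// indicates whether any groups were truncated.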
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // Sort by proximity to the cursor.
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic context mode is not supported in zeta-cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        zeta_project
            .syntax_index
            .read(cx)
            .wait_for_initial_file_indexing(cx)
    }
}

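/// Extracts the assistant's plain-text output from an OpenAI-style response,
/// taking the last choice. Returns `None` (after logging an error) when the
/// response contains no usable text.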
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output from Baseten completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}

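/// Builds the cloud request for the syntax-based context mode: resolves each
/// retrieved declaration's path against the worktrees, records its scores, and
/// deduplicates parent signatures via `add_signature`.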
fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &[worktree::Snapshot],
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

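/// Recursively adds the signature of `declaration_id` and its parents to
/// `signatures`, memoizing indices in `declaration_to_signature_index` so each
/// declaration is serialized at most once. Returns the signature's index.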
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}
1774
1775#[cfg(test)]
1776mod tests {
1777 use std::{path::Path, sync::Arc};
1778
1779 use client::UserStore;
1780 use clock::FakeSystemClock;
1781 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1782 use futures::{
1783 AsyncReadExt, StreamExt,
1784 channel::{mpsc, oneshot},
1785 };
1786 use gpui::{
1787 Entity, TestAppContext,
1788 http_client::{FakeHttpClient, Response},
1789 prelude::*,
1790 };
1791 use indoc::indoc;
1792 use language::OffsetRangeExt as _;
1793 use open_ai::Usage;
1794 use pretty_assertions::{assert_eq, assert_matches};
1795 use project::{FakeFs, Project};
1796 use serde_json::json;
1797 use settings::SettingsStore;
1798 use util::path;
1799 use uuid::Uuid;
1800
1801 use crate::{BufferEditPrediction, Zeta};
1802
1803 #[gpui::test]
1804 async fn test_current_state(cx: &mut TestAppContext) {
1805 let (zeta, mut req_rx) = init_test(cx);
1806 let fs = FakeFs::new(cx.executor());
1807 fs.insert_tree(
1808 "/root",
1809 json!({
1810 "1.txt": "Hello!\nHow\nBye\n",
1811 "2.txt": "Hola!\nComo\nAdios\n"
1812 }),
1813 )
1814 .await;
1815 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1816
1817 zeta.update(cx, |zeta, cx| {
1818 zeta.register_project(&project, cx);
1819 });
1820
1821 let buffer1 = project
1822 .update(cx, |project, cx| {
1823 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1824 project.open_buffer(path, cx)
1825 })
1826 .await
1827 .unwrap();
1828 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1829 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1830
1831 // Prediction for current file
1832
1833 let prediction_task = zeta.update(cx, |zeta, cx| {
1834 zeta.refresh_prediction(&project, &buffer1, position, cx)
1835 });
1836 let (_request, respond_tx) = req_rx.next().await.unwrap();
1837
1838 respond_tx
1839 .send(model_response(indoc! {r"
1840 --- a/root/1.txt
1841 +++ b/root/1.txt
1842 @@ ... @@
1843 Hello!
1844 -How
1845 +How are you?
1846 Bye
1847 "}))
1848 .unwrap();
1849 prediction_task.await.unwrap();
1850
1851 zeta.read_with(cx, |zeta, cx| {
1852 let prediction = zeta
1853 .current_prediction_for_buffer(&buffer1, &project, cx)
1854 .unwrap();
1855 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1856 });
1857
1858 // Context refresh
1859 let refresh_task = zeta.update(cx, |zeta, cx| {
1860 zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
1861 });
1862 let (_request, respond_tx) = req_rx.next().await.unwrap();
1863 respond_tx
1864 .send(open_ai::Response {
1865 id: Uuid::new_v4().to_string(),
1866 object: "response".into(),
1867 created: 0,
1868 model: "model".into(),
1869 choices: vec![open_ai::Choice {
1870 index: 0,
1871 message: open_ai::RequestMessage::Assistant {
1872 content: None,
1873 tool_calls: vec![open_ai::ToolCall {
1874 id: "search".into(),
1875 content: open_ai::ToolCallContent::Function {
1876 function: open_ai::FunctionContent {
1877 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
1878 .to_string(),
1879 arguments: serde_json::to_string(&SearchToolInput {
1880 queries: Box::new([SearchToolQuery {
1881 glob: "root/2.txt".to_string(),
1882 syntax_node: vec![],
1883 content: Some(".".into()),
1884 }]),
1885 })
1886 .unwrap(),
1887 },
1888 },
1889 }],
1890 },
1891 finish_reason: None,
1892 }],
1893 usage: Usage {
1894 prompt_tokens: 0,
1895 completion_tokens: 0,
1896 total_tokens: 0,
1897 },
1898 })
1899 .unwrap();
1900 refresh_task.await.unwrap();
1901
1902 zeta.update(cx, |zeta, _cx| {
1903 zeta.discard_current_prediction(&project);
1904 });
1905
1906 // Prediction for another file
1907 let prediction_task = zeta.update(cx, |zeta, cx| {
1908 zeta.refresh_prediction(&project, &buffer1, position, cx)
1909 });
1910 let (_request, respond_tx) = req_rx.next().await.unwrap();
1911 respond_tx
1912 .send(model_response(indoc! {r#"
1913 --- a/root/2.txt
1914 +++ b/root/2.txt
1915 Hola!
1916 -Como
1917 +Como estas?
1918 Adios
1919 "#}))
1920 .unwrap();
1921 prediction_task.await.unwrap();
1922 zeta.read_with(cx, |zeta, cx| {
1923 let prediction = zeta
1924 .current_prediction_for_buffer(&buffer1, &project, cx)
1925 .unwrap();
1926 assert_matches!(
1927 prediction,
1928 BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
1929 );
1930 });
1931
1932 let buffer2 = project
1933 .update(cx, |project, cx| {
1934 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1935 project.open_buffer(path, cx)
1936 })
1937 .await
1938 .unwrap();
1939
1940 zeta.read_with(cx, |zeta, cx| {
1941 let prediction = zeta
1942 .current_prediction_for_buffer(&buffer2, &project, cx)
1943 .unwrap();
1944 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1945 });
1946 }

    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in the prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO: define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

    /// Builds a fake OpenAI-style chat response whose single assistant
    /// message contains `text` as its plain content.
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }
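
    /// Like [`model_response`], but the assistant message carries a single
    /// call to the context-retrieval search tool instead of text content.
    fn search_tool_response(input: SearchToolInput) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: None,
                    tool_calls: vec![open_ai::ToolCall {
                        id: "search".into(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                                arguments: serde_json::to_string(&input).unwrap(),
                            },
                        },
                    }],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }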

    /// Extracts the prompt text from a request, panicking unless the request
    /// consists of exactly one plain-text user message.
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have a single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

    /// Creates a test `Zeta` backed by a fake HTTP client. Prediction requests
    /// are forwarded on the returned channel, together with a oneshot sender
    /// the test uses to reply with a canned `open_ai::Response`.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: respond with a static test token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction requests are forwarded to the test,
                            // which replies through the oneshot channel.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}