use anyhow::{Context as _, Result, anyhow, bail};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};

use std::env;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod merge_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
pub mod udiff;
mod xml_edits;

use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

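// Local-development overrides: setting ZED_ZETA2_OLLAMA routes requests to a local
// Ollama server, ZED_ZETA2_MODEL overrides the model ID, and ZED_PREDICT_EDITS_URL
// overrides the prediction endpoint.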
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

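/// How context is gathered for a prediction request: `Agentic` plans LLM-driven
/// searches to find related excerpts, while `Syntax` retrieves declarations via the
/// syntax index.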
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
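    /// Decides whether this prediction should replace `old_prediction` as the
    /// project's current prediction.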
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}

impl Event {
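    /// Converts this event into the request wire format, returning `None` when the
    /// change is a no-op (same path and an empty diff).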
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2]: move this to a background thread?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: actually detect whether this edit was predicted
                        predicted: false,
                    })
                }
            }
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

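    /// Returns the shared [`Zeta`] entity, creating it and registering it as a global
    /// on first use. A minimal sketch of a call site (hypothetical):
    ///
    /// ```ignore
    /// let zeta = Zeta::global(&client, &user_store, cx);
    /// ```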
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
        }
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

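    /// Appends an event to the project's history, coalescing consecutive edits to the
    /// same buffer that land within `buffer_change_grouping_interval` of each other.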
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // The event list is halved instead of popped one entry at a time to improve
            // prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
        };
    }

    pub fn refresh_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let request_task = self.request_prediction(project, buffer, position, cx);
        let buffer = buffer.clone();
        let project = project.clone();

        cx.spawn(async move |this, cx| {
            if let Some(prediction) = request_task.await? {
                this.update(cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project.entity_id())
                        .context("Project not found")?;

                    let new_prediction = CurrentEditPrediction {
                        requested_by_buffer_id: buffer.entity_id(),
                        prediction,
                    };

                    if project_state
                        .current_prediction
                        .as_ref()
                        .is_none_or(|old_prediction| {
                            new_prediction.should_replace_prediction(old_prediction, cx)
                        })
                    {
                        project_state.current_prediction = Some(new_prediction);
                    }
                    anyhow::Ok(())
                })??;
            }
            Ok(())
        })
    }

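    /// Requests an edit prediction for the cursor position, building the prompt from
    /// the active excerpt, gathered context, recent events, and nearby diagnostics,
    /// then parsing the model's response into buffer edits.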
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO: data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &snapshot,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

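    /// Unwraps an API response, recording any usage data on the user store, and showing
    /// an "update required" notification when the server rejects this Zed version.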
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

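    /// Sends an authenticated POST request to the LLM service. Fails with
    /// `ZedUpdateRequiredError` when the server demands a newer Zed version, and
    /// refreshes the LLM token and retries once when the server reports it expired.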
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }

    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

    pub fn set_context(
        &mut self,
        project: Entity<Project>,
        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
    ) {
        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
            zeta_project.context = Some(context);
        }
    }

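    /// Collects diagnostic groups sorted by proximity to the cursor, serializing each
    /// group and stopping once the total exceeds `max_diagnostics_bytes`. The returned
    /// flag reports whether any groups were dropped.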
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // Sort by proximity to the cursor.
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic context mode is not supported in zeta-cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO: pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        zeta_project
            .syntax_index
            .read(cx)
            .wait_for_initial_file_indexing(cx)
    }
}

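/// Extracts the assistant's text output from an OpenAI-style chat completion response,
/// taking the first text part when the message is multipart.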
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output from Baseten completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}

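/// Builds a cloud prediction request from syntax-based context, resolving each
/// declaration's path against the worktrees and interning parent signatures so that
/// shared ancestors are sent only once.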
fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &[worktree::Snapshot],
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

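/// Recursively adds the signature of `declaration_id` (and its ancestors) to
/// `signatures`, memoizing indices in `declaration_to_signature_index` so each
/// declaration is serialized at most once. Returns the signature's index.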
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}

#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

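/// Storage used to cache LLM requests and responses across eval runs, keyed by a hash
/// of the request. A minimal in-memory sketch (hypothetical, for illustration only):
///
/// ```ignore
/// use std::sync::Mutex;
///
/// #[derive(Default)]
/// struct MemoryCache(Mutex<Vec<(EvalCacheKey, String)>>);
///
/// impl EvalCache for MemoryCache {
///     fn read(&self, key: EvalCacheKey) -> Option<String> {
///         let entries = self.0.lock().unwrap();
///         entries.iter().find(|(k, _)| *k == key).map(|(_, v)| v.clone())
///     }
///     fn write(&self, key: EvalCacheKey, _input: &str, value: &str) {
///         self.0.lock().unwrap().push((key, value.to_string()));
///     }
/// }
/// ```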
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}

1829#[cfg(test)]
1830mod tests {
1831 use std::{path::Path, sync::Arc};
1832
1833 use client::UserStore;
1834 use clock::FakeSystemClock;
1835 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1836 use futures::{
1837 AsyncReadExt, StreamExt,
1838 channel::{mpsc, oneshot},
1839 };
1840 use gpui::{
1841 Entity, TestAppContext,
1842 http_client::{FakeHttpClient, Response},
1843 prelude::*,
1844 };
1845 use indoc::indoc;
1846 use language::OffsetRangeExt as _;
1847 use open_ai::Usage;
1848 use pretty_assertions::{assert_eq, assert_matches};
1849 use project::{FakeFs, Project};
1850 use serde_json::json;
1851 use settings::SettingsStore;
1852 use util::path;
1853 use uuid::Uuid;
1854
1855 use crate::{BufferEditPrediction, Zeta};
1856
1857 #[gpui::test]
1858 async fn test_current_state(cx: &mut TestAppContext) {
1859 let (zeta, mut req_rx) = init_test(cx);
1860 let fs = FakeFs::new(cx.executor());
1861 fs.insert_tree(
1862 "/root",
1863 json!({
1864 "1.txt": "Hello!\nHow\nBye\n",
1865 "2.txt": "Hola!\nComo\nAdios\n"
1866 }),
1867 )
1868 .await;
1869 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1870
1871 zeta.update(cx, |zeta, cx| {
1872 zeta.register_project(&project, cx);
1873 });
1874
1875 let buffer1 = project
1876 .update(cx, |project, cx| {
1877 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1878 project.open_buffer(path, cx)
1879 })
1880 .await
1881 .unwrap();
1882 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1883 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1884
1885 // Prediction for current file
1886
1887 let prediction_task = zeta.update(cx, |zeta, cx| {
1888 zeta.refresh_prediction(&project, &buffer1, position, cx)
1889 });
1890 let (_request, respond_tx) = req_rx.next().await.unwrap();
1891
1892 respond_tx
1893 .send(model_response(indoc! {r"
1894 --- a/root/1.txt
1895 +++ b/root/1.txt
1896 @@ ... @@
1897 Hello!
1898 -How
1899 +How are you?
1900 Bye
1901 "}))
1902 .unwrap();
1903 prediction_task.await.unwrap();
1904
1905 zeta.read_with(cx, |zeta, cx| {
1906 let prediction = zeta
1907 .current_prediction_for_buffer(&buffer1, &project, cx)
1908 .unwrap();
1909 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1910 });
1911
1912 // Context refresh
1913 let refresh_task = zeta.update(cx, |zeta, cx| {
1914 zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
1915 });
1916 let (_request, respond_tx) = req_rx.next().await.unwrap();
1917 respond_tx
1918 .send(open_ai::Response {
1919 id: Uuid::new_v4().to_string(),
1920 object: "response".into(),
1921 created: 0,
1922 model: "model".into(),
1923 choices: vec![open_ai::Choice {
1924 index: 0,
1925 message: open_ai::RequestMessage::Assistant {
1926 content: None,
1927 tool_calls: vec![open_ai::ToolCall {
1928 id: "search".into(),
1929 content: open_ai::ToolCallContent::Function {
1930 function: open_ai::FunctionContent {
1931 name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
1932 .to_string(),
1933 arguments: serde_json::to_string(&SearchToolInput {
1934 queries: Box::new([SearchToolQuery {
1935 glob: "root/2.txt".to_string(),
1936 syntax_node: vec![],
1937 content: Some(".".into()),
1938 }]),
1939 })
1940 .unwrap(),
1941 },
1942 },
1943 }],
1944 },
1945 finish_reason: None,
1946 }],
1947 usage: Usage {
1948 prompt_tokens: 0,
1949 completion_tokens: 0,
1950 total_tokens: 0,
1951 },
1952 })
1953 .unwrap();
1954 refresh_task.await.unwrap();
1955
1956 zeta.update(cx, |zeta, _cx| {
1957 zeta.discard_current_prediction(&project);
1958 });
1959
        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "}))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

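        // Opening the predicted file should surface the same prediction as
        // `Local` for that buffer.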
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

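    // Happy path: a single prediction round trip whose returned diff becomes
    // one edit at the cursor position.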
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

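    // Edits made to a registered buffer are recorded as events and included
    // in the next prediction prompt as a unified diff.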
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

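        // Edit the registered buffer so the change is captured as an event
        // for the next prediction request.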
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

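        // The recorded edit should show up in the prompt as a diff of the
        // buffer change.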
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

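    /// Builds a minimal OpenAI-style response whose single assistant message
    /// carries `text`, the unified diff the fake model "predicts".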
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

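    /// Extracts the prompt text from a request, panicking unless it consists
    /// of exactly one plain-text user message.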
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have a single plain-text user message: {:#?}",
                request
            );
        };
        content
    }

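    /// Sets up a `Zeta` instance backed by a fake HTTP client. LLM token
    /// requests are answered with a static token, and each body posted to
    /// `/predict_edits/raw` is forwarded on the returned channel together
    /// with a oneshot sender for the test to reply through.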
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
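                        // Route by path: mint a dummy LLM token, or hand the
                        // prediction request to the test and await its reply.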
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}