use anyhow::{Context as _, Result, anyhow, bail};
use chrono::TimeDelta;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
use cloud_llm_client::{
    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
use collections::HashMap;
use edit_prediction_context::{
    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
    SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
use futures::channel::{mpsc, oneshot};
use gpui::http_client::{AsyncBody, Method};
use gpui::{
    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
    http_client, prelude::*,
};
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use open_ai::FunctionDefinition;
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};

use std::env;
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use thiserror::Error;
use util::rel_path::RelPathBuf;
use util::{LogErrorFuture, TryFutureExt};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};

pub mod merge_excerpts;
mod prediction;
mod provider;
pub mod retrieval_search;
pub mod udiff;
mod xml_edits;

use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

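/// Default excerpt sizing. With `target_before_cursor_over_total_bytes` set
/// to 0.5, roughly half of the excerpt's byte budget is spent on text before
/// the cursor.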
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});

pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    fn enabled_for_staff() -> bool {
        false
    }
}

#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    update_required: bool,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
}

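/// Tuning options for context gathering, prompt construction, and event
/// grouping.
///
/// A minimal sketch of overriding a default (the values here are
/// illustrative, not recommendations):
///
/// ```ignore
/// let options = ZetaOptions {
///     max_diagnostic_bytes: 4096,
///     ..DEFAULT_OPTIONS
/// };
/// zeta.update(cx, |zeta, _cx| zeta.set_options(options));
/// ```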
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    pub buffer_change_grouping_interval: Duration,
}

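/// How context is gathered for a prediction request: either model-driven
/// ("agentic") search over the project, or syntax-tree-based retrieval.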
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    Agentic(AgenticContextOptions),
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

impl ContextMode {
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Agentic(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}

#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

struct ZetaProject {
    syntax_index: Option<Entity<SyntaxIndex>>,
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}

impl CurrentEditPrediction {
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}

/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    Local { prediction: &'a EditPrediction },
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}

impl Event {
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        // TODO: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}

impl Zeta {
    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
    }

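    /// Returns the shared [`Zeta`] instance, creating it and registering it
    /// as a GPUI global on first use.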
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
        }
    }

    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }

    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
        self.projects
            .get(&project.entity_id())
            .map(|project| project.events.iter())
            .into_iter()
            .flatten()
    }

    pub fn context_for_project(
        &self,
        project: &Entity<Project>,
    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
        self.projects
            .get(&project.entity_id())
            .and_then(|project| {
                Some(
                    project
                        .context
                        .as_ref()?
                        .iter()
                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
                )
            })
            .into_iter()
            .flatten()
    }

    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }

    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

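    /// Appends a buffer-change event to the project's history, coalescing it
    /// with the previous event when it continues the same buffer's edits
    /// within `buffer_change_grouping_interval`.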
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }

    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
            project_state.current_prediction.take();
        };
    }

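    /// Requests a new prediction and stores it as the project's current
    /// prediction unless the existing one should be kept.
    ///
    /// A minimal sketch of driving a refresh from an update closure
    /// (identifiers are illustrative):
    ///
    /// ```ignore
    /// zeta.update(cx, |zeta, cx| {
    ///     zeta.refresh_prediction(&project, &buffer, cursor_anchor, cx)
    /// })
    /// .detach_and_log_err(cx);
    /// ```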
    pub fn refresh_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let request_task = self.request_prediction(project, buffer, position, cx);
        let buffer = buffer.clone();
        let project = project.clone();

        cx.spawn(async move |this, cx| {
            if let Some(prediction) = request_task.await? {
                this.update(cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project.entity_id())
                        .context("Project not found")?;

                    let new_prediction = CurrentEditPrediction {
                        requested_by_buffer_id: buffer.entity_id(),
                        prediction,
                    };

                    if project_state
                        .current_prediction
                        .as_ref()
                        .is_none_or(|old_prediction| {
                            new_prediction.should_replace_prediction(old_prediction, cx)
                        })
                    {
                        project_state.current_prediction = Some(new_prediction);
                    }
                    anyhow::Ok(())
                })??;
            }
            Ok(())
        })
    }

    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &snapshot,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }

    async fn send_raw_llm_request(
        request: open_ai::Request,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        #[cfg(feature = "eval-support")]
        let cache_key = if let Some(cache) = eval_cache {
            use collections::FxHasher;
            use std::hash::{Hash, Hasher};

            let mut hasher = FxHasher::default();
            url.hash(&mut hasher);
            let request_str = serde_json::to_string_pretty(&request)?;
            request_str.hash(&mut hasher);
            let hash = hasher.finish();

            let key = (eval_cache_kind, hash);
            if let Some(response_str) = cache.read(key) {
                return Ok((serde_json::from_str(&response_str)?, None));
            }

            Some((cache, request_str, key))
        } else {
            None
        };

        let (response, usage) = Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await?;

        #[cfg(feature = "eval-support")]
        if let Some((cache, request, key)) = cache_key {
            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
        }

        Ok((response, usage))
    }

    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }

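    /// Sends an authenticated request to the LLM backend, enforcing the
    /// server's minimum-required Zed version and retrying once with a
    /// refreshed token when the LLM token has expired.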
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }

    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and again after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }

    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }

    pub fn set_context(
        &mut self,
        project: Entity<Project>,
        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
    ) {
        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
            zeta_project.context = Some(context);
        }
    }

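    /// Collects diagnostic groups sorted by proximity to the cursor, stopping
    /// once their serialized size would exceed `max_diagnostics_bytes` and
    /// reporting whether any groups were truncated.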
    fn gather_nearby_diagnostics(
        cursor_offset: usize,
        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
        snapshot: &BufferSnapshot,
        max_diagnostics_bytes: usize,
    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
        // TODO: Could make this more efficient
        let mut diagnostic_groups = Vec::new();
        for (language_server_id, diagnostics) in diagnostic_sets {
            let mut groups = Vec::new();
            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
            diagnostic_groups.extend(
                groups
                    .into_iter()
                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
            );
        }

        // sort by proximity to cursor
        diagnostic_groups.sort_by_key(|group| {
            let range = &group.entries[group.primary_ix].range;
            if range.start >= cursor_offset {
                range.start - cursor_offset
            } else if cursor_offset >= range.end {
                cursor_offset - range.end
            } else {
                (cursor_offset - range.start).min(range.end - cursor_offset)
            }
        });

        let mut results = Vec::new();
        let mut diagnostic_groups_truncated = false;
        let mut diagnostics_byte_count = 0;
        for group in diagnostic_groups {
            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
            diagnostics_byte_count += raw_value.get().len();
            if diagnostics_byte_count > max_diagnostics_bytes {
                diagnostic_groups_truncated = true;
                break;
            }
            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
        }

        (results, diagnostic_groups_truncated)
    }

    // TODO: Dedupe with similar code in request_prediction?
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Agentic context mode is not supported in zeta-cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }

    pub fn wait_for_initial_indexing(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> Task<Result<()>> {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        if let Some(syntax_index) = &zeta_project.syntax_index {
            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
        } else {
            Task::ready(Ok(()))
        }
    }
}

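/// Extracts the assistant's text from an OpenAI-style chat completion
/// response, taking the first text part when the message is multipart.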
pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
    let choice = res.choices.pop()?;
    let output_text = match choice.message {
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Plain(content)),
            ..
        } => content,
        open_ai::RequestMessage::Assistant {
            content: Some(open_ai::MessageContent::Multipart(mut content)),
            ..
        } => {
            if content.is_empty() {
                log::error!("No output from Baseten completion response");
                return None;
            }

            match content.remove(0) {
                open_ai::MessagePart::Text { text } => text,
                open_ai::MessagePart::Image { .. } => {
                    log::error!("Expected text, got an image");
                    return None;
                }
            }
        }
        _ => {
            log::error!("Invalid response message: {:?}", choice.message);
            return None;
        }
    };
    Some(output_text)
}

#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}

fn make_syntax_context_cloud_request(
    excerpt_path: Arc<Path>,
    context: EditPredictionContext,
    events: Vec<predict_edits_v3::Event>,
    can_collect_data: bool,
    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
    diagnostic_groups_truncated: bool,
    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
    debug_info: bool,
    worktrees: &Vec<worktree::Snapshot>,
    index_state: Option<&SyntaxIndexState>,
    prompt_max_bytes: Option<usize>,
    prompt_format: PromptFormat,
) -> predict_edits_v3::PredictEditsRequest {
    let mut signatures = Vec::new();
    let mut declaration_to_signature_index = HashMap::default();
    let mut referenced_declarations = Vec::new();

    for snippet in context.declarations {
        let project_entry_id = snippet.declaration.project_entry_id();
        let Some(path) = worktrees.iter().find_map(|worktree| {
            worktree.entry_for_id(project_entry_id).map(|entry| {
                let mut full_path = RelPathBuf::new();
                full_path.push(worktree.root_name());
                full_path.push(&entry.path);
                full_path
            })
        }) else {
            continue;
        };

        let parent_index = index_state.and_then(|index_state| {
            snippet.declaration.parent().and_then(|parent| {
                add_signature(
                    parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
        });

        let (text, text_is_truncated) = snippet.declaration.item_text();
        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
            path: path.as_std_path().into(),
            text: text.into(),
            range: snippet.declaration.item_line_range(),
            text_is_truncated,
            signature_range: snippet.declaration.signature_range_in_item_text(),
            parent_index,
            signature_score: snippet.score(DeclarationStyle::Signature),
            declaration_score: snippet.score(DeclarationStyle::Declaration),
            score_components: snippet.components,
        });
    }

    let excerpt_parent = index_state.and_then(|index_state| {
        context
            .excerpt
            .parent_declarations
            .last()
            .and_then(|(parent, _)| {
                add_signature(
                    *parent,
                    &mut declaration_to_signature_index,
                    &mut signatures,
                    index_state,
                )
            })
    });

    predict_edits_v3::PredictEditsRequest {
        excerpt_path,
        excerpt: context.excerpt_text.body,
        excerpt_line_range: context.excerpt.line_range,
        excerpt_range: context.excerpt.range,
        cursor_point: predict_edits_v3::Point {
            line: predict_edits_v3::Line(context.cursor_point.row),
            column: context.cursor_point.column,
        },
        referenced_declarations,
        included_files: vec![],
        signatures,
        excerpt_parent,
        events,
        can_collect_data,
        diagnostic_groups,
        diagnostic_groups_truncated,
        git_info,
        debug_info,
        prompt_max_bytes,
        prompt_format,
    }
}

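/// Recursively interns the signature of `declaration_id` (and its parents)
/// into `signatures`, returning the signature's index. Results are memoized
/// in `declaration_to_signature_index`.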
fn add_signature(
    declaration_id: DeclarationId,
    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
    signatures: &mut Vec<Signature>,
    index: &SyntaxIndexState,
) -> Option<usize> {
    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
        return Some(*signature_index);
    }
    let Some(parent_declaration) = index.declaration(declaration_id) else {
        log::error!("bug: missing parent declaration");
        return None;
    };
    let parent_index = parent_declaration.parent().and_then(|parent| {
        add_signature(parent, declaration_to_signature_index, signatures, index)
    });
    let (text, text_is_truncated) = parent_declaration.signature_text();
    let signature_index = signatures.len();
    signatures.push(Signature {
        text: text.into(),
        text_is_truncated,
        parent_index,
        range: parent_declaration.signature_line_range(),
    });
    declaration_to_signature_index.insert(declaration_id, signature_index);
    Some(signature_index)
}

#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}

#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EvalCacheEntryKind::Search => write!(f, "search"),
            EvalCacheEntryKind::Context => write!(f, "context"),
            EvalCacheEntryKind::Prediction => write!(f, "prediction"),
        }
    }
}

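/// Storage for caching LLM responses during evals, keyed by a hash of the
/// request.
///
/// A minimal in-memory sketch (illustrative only; a real cache would likely
/// persist to disk):
///
/// ```ignore
/// use std::sync::Mutex;
///
/// struct MemoryCache(Mutex<Vec<(EvalCacheKey, String)>>);
///
/// impl EvalCache for MemoryCache {
///     fn read(&self, key: EvalCacheKey) -> Option<String> {
///         let entries = self.0.lock().unwrap();
///         entries.iter().find(|(k, _)| *k == key).map(|(_, v)| v.clone())
///     }
///
///     fn write(&self, key: EvalCacheKey, _input: &str, value: &str) {
///         self.0.lock().unwrap().push((key, value.to_string()));
///     }
/// }
/// ```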
1848#[cfg(feature = "eval-support")]
1849pub trait EvalCache: Send + Sync {
1850 fn read(&self, key: EvalCacheKey) -> Option<String>;
1851 fn write(&self, key: EvalCacheKey, input: &str, value: &str);
1852}
1853
1854#[cfg(test)]
1855mod tests {
1856 use std::{path::Path, sync::Arc};
1857
1858 use client::UserStore;
1859 use clock::FakeSystemClock;
1860 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1861 use futures::{
1862 AsyncReadExt, StreamExt,
1863 channel::{mpsc, oneshot},
1864 };
1865 use gpui::{
1866 Entity, TestAppContext,
1867 http_client::{FakeHttpClient, Response},
1868 prelude::*,
1869 };
1870 use indoc::indoc;
1871 use language::OffsetRangeExt as _;
1872 use open_ai::Usage;
1873 use pretty_assertions::{assert_eq, assert_matches};
1874 use project::{FakeFs, Project};
1875 use serde_json::json;
1876 use settings::SettingsStore;
1877 use util::path;
1878 use uuid::Uuid;
1879
1880 use crate::{BufferEditPrediction, Zeta};
1881
1882 #[gpui::test]
1883 async fn test_current_state(cx: &mut TestAppContext) {
1884 let (zeta, mut req_rx) = init_test(cx);
1885 let fs = FakeFs::new(cx.executor());
1886 fs.insert_tree(
1887 "/root",
1888 json!({
1889 "1.txt": "Hello!\nHow\nBye\n",
1890 "2.txt": "Hola!\nComo\nAdios\n"
1891 }),
1892 )
1893 .await;
1894 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1895
1896 zeta.update(cx, |zeta, cx| {
1897 zeta.register_project(&project, cx);
1898 });
1899
1900 let buffer1 = project
1901 .update(cx, |project, cx| {
1902 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1903 project.open_buffer(path, cx)
1904 })
1905 .await
1906 .unwrap();
1907 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1908 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1909
1910 // Prediction for current file
1911
1912 let prediction_task = zeta.update(cx, |zeta, cx| {
1913 zeta.refresh_prediction(&project, &buffer1, position, cx)
1914 });
1915 let (_request, respond_tx) = req_rx.next().await.unwrap();
1916
1917 respond_tx
1918 .send(model_response(indoc! {r"
1919 --- a/root/1.txt
1920 +++ b/root/1.txt
1921 @@ ... @@
1922 Hello!
1923 -How
1924 +How are you?
1925 Bye
1926 "}))
1927 .unwrap();
1928 prediction_task.await.unwrap();
1929
1930 zeta.read_with(cx, |zeta, cx| {
1931 let prediction = zeta
1932 .current_prediction_for_buffer(&buffer1, &project, cx)
1933 .unwrap();
1934 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1935 });
1936
1937 // Context refresh
1938 let refresh_task = zeta.update(cx, |zeta, cx| {
1939 zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
1940 });
1941 let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
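        // The diff below edits a different file than the cursor's buffer, so
        // the prediction should surface as a jump to that file.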
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                @@ ... @@
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

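        // The edit made after registering the buffer should be reported to
        // the model as a diff event in the prompt.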
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in the prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

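    /// Builds a minimal OpenAI-style response whose single assistant message
    /// carries `text` as plain content and no tool calls.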
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

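    /// Extracts the prompt from a request, asserting that it consists of a
    /// single plain-text user message.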
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have a single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

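    /// Sets up a `Zeta` instance backed by a fake HTTP client, returning it
    /// along with a channel of intercepted prediction requests paired with
    /// one-shot responders.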
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

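            // Fake HTTP endpoints: `/client/llm_tokens` mints a static token,
            // and `/predict_edits/raw` forwards each decoded request to the
            // test over `req_tx` so it can reply with a handcrafted response.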
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {uri}")
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}