1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::Project;
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33use uuid::Uuid;
34
35use std::ops::Range;
36use std::path::Path;
37use std::str::FromStr as _;
38use std::sync::{Arc, LazyLock};
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod merge_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49pub mod udiff;
50
51use crate::merge_excerpts::merge_excerpts;
52use crate::prediction::EditPrediction;
53pub use crate::prediction::EditPredictionId;
54pub use provider::ZetaEditPredictionProvider;
55
/// Maximum number of events to track per project.
///
/// Once a project's history reaches this size, `Zeta::push_event` drops the
/// oldest half of the history before appending the next event.
const MAX_EVENT_COUNT: usize = 16;
58
/// Default sizing options for the excerpt selected from the active buffer.
///
/// Shared by both the agentic and the syntax default context options below.
// NOTE(review): the exact meaning of each field is defined by
// `EditPredictionExcerptOptions` in `edit_prediction_context` — confirm there.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
64
/// Default context mode: agentic, configured with [`DEFAULT_AGENTIC_CONTEXT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);
67
/// Default options for [`ContextMode::Agentic`]; uses [`DEFAULT_EXCERPT_OPTIONS`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
71
/// Default options for [`ContextMode::Syntax`]; uses [`DEFAULT_EXCERPT_OPTIONS`].
///
/// Note that `max_retrieved_declarations` defaults to `0`.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
81
/// Options a freshly constructed [`Zeta`] starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    // Consecutive edits to the same buffer within this interval are merged into one event.
    buffer_change_grouping_interval: Duration::from_secs(1),
};
90
/// Model identifier placed in the `model` field of raw edit prediction requests.
///
/// Read once, on first access, from the `ZED_ZETA2_MODEL` environment variable.
/// Falls back to the built-in default when the variable is unset or is not
/// valid Unicode.
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    // `unwrap_or_else` keeps the fallback allocation off the path where the
    // environment variable is set.
    std::env::var("ZED_ZETA2_MODEL").unwrap_or_else(|_| "yqvev8r3".to_string())
});
93
/// Feature flag named `"zeta2"`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// Staff do not get this flag automatically.
    fn enabled_for_staff() -> bool {
        false
    }
}
103
/// App-global wrapper holding the shared [`Zeta`] entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
108
/// Edit prediction state shared across projects.
pub struct Zeta {
    /// Client used to build request URLs, send HTTP requests and acquire LLM tokens.
    client: Arc<Client>,
    /// Source of, and sink for, edit prediction usage information.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential on API requests.
    llm_token: LlmApiToken,
    /// Keeps alive the subscription that refreshes `llm_token` on listener events.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Current options; start out as [`DEFAULT_OPTIONS`].
    options: ZetaOptions,
    /// Set to `true` once a response reports that a newer Zed version is required.
    update_required: bool,
    /// Sender for debug events; `Some` only after `debug_info` has been called.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
119
/// Tunable options for [`Zeta`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for a prediction request is gathered.
    pub context: ContextMode,
    /// Forwarded to the request as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget passed to `gather_nearby_diagnostics`.
    pub max_diagnostic_bytes: usize,
    /// Forwarded to the request as `prompt_format`.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
    /// Consecutive edits to the same buffer within this interval are coalesced into
    /// one event; `Duration::ZERO` disables coalescing.
    pub buffer_change_grouping_interval: Duration,
}
129
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// The request carries excerpts from the project's stored context buffers plus the
    /// cursor excerpt (`included_files`).
    Agentic(AgenticContextOptions),
    /// The request is built from `EditPredictionContext::gather_context`.
    Syntax(EditPredictionContextOptions),
}
135
/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    /// Options used when selecting the excerpt around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
}
140
141impl ContextMode {
142 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
143 match self {
144 ContextMode::Agentic(options) => &options.excerpt,
145 ContextMode::Syntax(options) => &options.excerpt,
146 }
147 }
148}
149
/// Debug events delivered on the channel returned by `Zeta::debug_info`.
// NOTE(review): within this part of the file only `EditPredictionRequested` is sent;
// the retrieval/search variants are presumably emitted by the context refresh code.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// Sent by `request_prediction` once the request and local prompt have been built.
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
158
/// Payload of [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // NOTE(review): presumably the prompt used to generate search queries — confirm at the send site.
    pub search_prompt: String,
}
165
/// Payload shared by [`ZetaDebugInfo::SearchQueriesExecuted`] and
/// [`ZetaDebugInfo::ContextRetrievalFinished`].
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
171
/// Payload of [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    /// The request that was assembled for this prediction.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics/context and building the prompt.
    pub retrieval_time: TimeDelta,
    /// The buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// The cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// The locally built prompt, or the stringified error from building it.
    pub local_prompt: Result<String, String>,
    /// Resolves with the response (or a stringified error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}
181
/// Payload of [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}
188
/// Alias for the debug info type defined by the `predict_edits_v3` protocol module.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
190
/// Per-project edit prediction state.
struct ZetaProject {
    /// Syntax index created for the project when it is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by [`MAX_EVENT_COUNT`].
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently on offer for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Buffers and anchor ranges included as context in agentic-mode requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// The most recently started context refresh.
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    /// Pending debounced trigger for a context refresh; replaced on every new trigger.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When `refresh_context_if_needed` last ran for this project.
    refresh_context_timestamp: Option<Instant>,
}
201
/// A prediction together with the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer that was active when the prediction was requested.
    /// This may differ from the buffer the prediction edits.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
207
208impl CurrentEditPrediction {
209 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
210 let Some(new_edits) = self
211 .prediction
212 .interpolate(&self.prediction.buffer.read(cx))
213 else {
214 return false;
215 };
216
217 if self.prediction.buffer != old_prediction.prediction.buffer {
218 return true;
219 }
220
221 let Some(old_edits) = old_prediction
222 .prediction
223 .interpolate(&old_prediction.prediction.buffer.read(cx))
224 else {
225 return true;
226 };
227
228 // This reduces the occurrence of UI thrash from replacing edits
229 //
230 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
231 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
232 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
233 && old_edits.len() == 1
234 && new_edits.len() == 1
235 {
236 let (old_range, old_text) = &old_edits[0];
237 let (new_range, new_text) = &new_edits[0];
238 new_range == old_range && new_text.starts_with(old_text.as_ref())
239 } else {
240 true
241 }
242 }
243}
244
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets a different one.
    Jump { prediction: &'a EditPrediction },
}
251
/// A buffer whose edits are being recorded as events.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; becomes the `old_snapshot` of the next event.
    snapshot: BufferSnapshot,
    /// Buffer event subscription and release observer, kept alive with the entry.
    _subscriptions: [gpui::Subscription; 2],
}
256
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change was recorded; used to coalesce consecutive edits.
        timestamp: Instant,
    },
}
265
266impl Event {
267 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
268 match self {
269 Event::BufferChange {
270 old_snapshot,
271 new_snapshot,
272 ..
273 } => {
274 let path = new_snapshot.file().map(|f| f.full_path(cx));
275
276 let old_path = old_snapshot.file().and_then(|f| {
277 let old_path = f.full_path(cx);
278 if Some(&old_path) != path.as_ref() {
279 Some(old_path)
280 } else {
281 None
282 }
283 });
284
285 // TODO [zeta2] move to bg?
286 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
287
288 if path == old_path && diff.is_empty() {
289 None
290 } else {
291 Some(predict_edits_v3::Event::BufferChange {
292 old_path,
293 path,
294 diff,
295 //todo: Actually detect if this edit was predicted or not
296 predicted: false,
297 })
298 }
299 }
300 }
301 }
302}
303
304impl Zeta {
305 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
306 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
307 }
308
309 pub fn global(
310 client: &Arc<Client>,
311 user_store: &Entity<UserStore>,
312 cx: &mut App,
313 ) -> Entity<Self> {
314 cx.try_global::<ZetaGlobal>()
315 .map(|global| global.0.clone())
316 .unwrap_or_else(|| {
317 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
318 cx.set_global(ZetaGlobal(zeta.clone()));
319 zeta
320 })
321 }
322
323 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
324 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
325
326 Self {
327 projects: HashMap::default(),
328 client,
329 user_store,
330 options: DEFAULT_OPTIONS,
331 llm_token: LlmApiToken::default(),
332 _llm_token_subscription: cx.subscribe(
333 &refresh_llm_token_listener,
334 |this, _listener, _event, cx| {
335 let client = this.client.clone();
336 let llm_token = this.llm_token.clone();
337 cx.spawn(async move |_this, _cx| {
338 llm_token.refresh(&client).await?;
339 anyhow::Ok(())
340 })
341 .detach_and_log_err(cx);
342 },
343 ),
344 update_required: false,
345 debug_tx: None,
346 }
347 }
348
349 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
350 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
351 self.debug_tx = Some(debug_watch_tx);
352 debug_watch_rx
353 }
354
    /// Returns the options currently in effect.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
358
    /// Replaces the current options.
    ///
    /// State already created for registered projects is left untouched.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
362
363 pub fn clear_history(&mut self) {
364 for zeta_project in self.projects.values_mut() {
365 zeta_project.events.clear();
366 }
367 }
368
369 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
370 self.projects
371 .get(&project.entity_id())
372 .map(|project| project.events.iter())
373 .into_iter()
374 .flatten()
375 }
376
377 pub fn context_for_project(
378 &self,
379 project: &Entity<Project>,
380 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
381 self.projects
382 .get(&project.entity_id())
383 .and_then(|project| {
384 Some(
385 project
386 .context
387 .as_ref()?
388 .iter()
389 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
390 )
391 })
392 .into_iter()
393 .flatten()
394 }
395
    /// Returns the edit prediction usage currently held by the user store, if any.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
399
    /// Ensures per-project state exists for `project`, creating it if necessary.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
403
    /// Starts tracking edits to `buffer` as part of `project`.
    ///
    /// Creates the project state if needed. Registering an already registered buffer
    /// is a no-op.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
413
414 fn get_or_init_zeta_project(
415 &mut self,
416 project: &Entity<Project>,
417 cx: &mut App,
418 ) -> &mut ZetaProject {
419 self.projects
420 .entry(project.entity_id())
421 .or_insert_with(|| ZetaProject {
422 syntax_index: cx.new(|cx| {
423 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
424 }),
425 events: VecDeque::new(),
426 registered_buffers: HashMap::default(),
427 current_prediction: None,
428 context: None,
429 refresh_context_task: None,
430 refresh_context_debounce_task: None,
431 refresh_context_timestamp: None,
432 })
433 }
434
435 fn register_buffer_impl<'a>(
436 zeta_project: &'a mut ZetaProject,
437 buffer: &Entity<Buffer>,
438 project: &Entity<Project>,
439 cx: &mut Context<Self>,
440 ) -> &'a mut RegisteredBuffer {
441 let buffer_id = buffer.entity_id();
442 match zeta_project.registered_buffers.entry(buffer_id) {
443 hash_map::Entry::Occupied(entry) => entry.into_mut(),
444 hash_map::Entry::Vacant(entry) => {
445 let snapshot = buffer.read(cx).snapshot();
446 let project_entity_id = project.entity_id();
447 entry.insert(RegisteredBuffer {
448 snapshot,
449 _subscriptions: [
450 cx.subscribe(buffer, {
451 let project = project.downgrade();
452 move |this, buffer, event, cx| {
453 if let language::BufferEvent::Edited = event
454 && let Some(project) = project.upgrade()
455 {
456 this.report_changes_for_buffer(&buffer, &project, cx);
457 }
458 }
459 }),
460 cx.observe_release(buffer, move |this, _buffer, _cx| {
461 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
462 else {
463 return;
464 };
465 zeta_project.registered_buffers.remove(&buffer_id);
466 }),
467 ],
468 })
469 }
470 }
471 }
472
473 fn report_changes_for_buffer(
474 &mut self,
475 buffer: &Entity<Buffer>,
476 project: &Entity<Project>,
477 cx: &mut Context<Self>,
478 ) -> BufferSnapshot {
479 let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
480 let zeta_project = self.get_or_init_zeta_project(project, cx);
481 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
482
483 let new_snapshot = buffer.read(cx).snapshot();
484 if new_snapshot.version != registered_buffer.snapshot.version {
485 let old_snapshot =
486 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
487 Self::push_event(
488 zeta_project,
489 buffer_change_grouping_interval,
490 Event::BufferChange {
491 old_snapshot,
492 new_snapshot: new_snapshot.clone(),
493 timestamp: Instant::now(),
494 },
495 );
496 }
497
498 new_snapshot
499 }
500
501 fn push_event(
502 zeta_project: &mut ZetaProject,
503 buffer_change_grouping_interval: Duration,
504 event: Event,
505 ) {
506 let events = &mut zeta_project.events;
507
508 if buffer_change_grouping_interval > Duration::ZERO
509 && let Some(Event::BufferChange {
510 new_snapshot: last_new_snapshot,
511 timestamp: last_timestamp,
512 ..
513 }) = events.back_mut()
514 {
515 // Coalesce edits for the same buffer when they happen one after the other.
516 let Event::BufferChange {
517 old_snapshot,
518 new_snapshot,
519 timestamp,
520 } = &event;
521
522 if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
523 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
524 && old_snapshot.version == last_new_snapshot.version
525 {
526 *last_new_snapshot = new_snapshot.clone();
527 *last_timestamp = *timestamp;
528 return;
529 }
530 }
531
532 if events.len() >= MAX_EVENT_COUNT {
533 // These are halved instead of popping to improve prompt caching.
534 events.drain(..MAX_EVENT_COUNT / 2);
535 }
536
537 events.push_back(event);
538 }
539
540 fn current_prediction_for_buffer(
541 &self,
542 buffer: &Entity<Buffer>,
543 project: &Entity<Project>,
544 cx: &App,
545 ) -> Option<BufferEditPrediction<'_>> {
546 let project_state = self.projects.get(&project.entity_id())?;
547
548 let CurrentEditPrediction {
549 requested_by_buffer_id,
550 prediction,
551 } = project_state.current_prediction.as_ref()?;
552
553 if prediction.targets_buffer(buffer.read(cx)) {
554 Some(BufferEditPrediction::Local { prediction })
555 } else if *requested_by_buffer_id == buffer.entity_id() {
556 Some(BufferEditPrediction::Jump { prediction })
557 } else {
558 None
559 }
560 }
561
562 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
563 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
564 return;
565 };
566
567 let Some(prediction) = project_state.current_prediction.take() else {
568 return;
569 };
570 let request_id = prediction.prediction.id.into();
571
572 let client = self.client.clone();
573 let llm_token = self.llm_token.clone();
574 let app_version = AppVersion::global(cx);
575 cx.spawn(async move |this, cx| {
576 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
577 http_client::Url::parse(&predict_edits_url)?
578 } else {
579 client
580 .http_client()
581 .build_zed_llm_url("/predict_edits/accept", &[])?
582 };
583
584 let response = cx
585 .background_spawn(Self::send_api_request::<()>(
586 move |builder| {
587 let req = builder.uri(url.as_ref()).body(
588 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
589 );
590 Ok(req?)
591 },
592 client,
593 llm_token,
594 app_version,
595 ))
596 .await;
597
598 Self::handle_api_response(&this, response, cx)?;
599 anyhow::Ok(())
600 })
601 .detach_and_log_err(cx);
602 }
603
604 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
605 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
606 project_state.current_prediction.take();
607 };
608 }
609
610 pub fn refresh_prediction(
611 &mut self,
612 project: &Entity<Project>,
613 buffer: &Entity<Buffer>,
614 position: language::Anchor,
615 cx: &mut Context<Self>,
616 ) -> Task<Result<()>> {
617 let request_task = self.request_prediction(project, buffer, position, cx);
618 let buffer = buffer.clone();
619 let project = project.clone();
620
621 cx.spawn(async move |this, cx| {
622 if let Some(prediction) = request_task.await? {
623 this.update(cx, |this, cx| {
624 let project_state = this
625 .projects
626 .get_mut(&project.entity_id())
627 .context("Project not found")?;
628
629 let new_prediction = CurrentEditPrediction {
630 requested_by_buffer_id: buffer.entity_id(),
631 prediction: prediction,
632 };
633
634 if project_state
635 .current_prediction
636 .as_ref()
637 .is_none_or(|old_prediction| {
638 new_prediction.should_replace_prediction(&old_prediction, cx)
639 })
640 {
641 project_state.current_prediction = Some(new_prediction);
642 }
643 anyhow::Ok(())
644 })??;
645 }
646 Ok(())
647 })
648 }
649
    /// Builds and sends an edit prediction request for `position` in `active_buffer`.
    ///
    /// Everything that needs `cx` (snapshots, event history, diagnostics, stored
    /// context) is captured synchronously. The request is then assembled, sent, and
    /// its diff response parsed on the background executor; the final
    /// `EditPrediction` is constructed back in a foreground task.
    ///
    /// Resolves to `Ok(None)` when no excerpt/context can be selected around the
    /// cursor, when the response contains no output text, or when
    /// `EditPrediction::new` yields nothing.
    ///
    /// # Errors
    ///
    /// Fails when the active buffer has no file, the prompt cannot be built, the HTTP
    /// request fails, the response id is not a UUID, or the returned diff cannot be
    /// parsed or mapped back to an included buffer. In debug builds it also fails when
    /// `ZED_ZETA2_SKIP_REQUEST` is set.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // The project may not be registered; every use below tolerates `None`.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert the recorded history into request events, dropping no-op entries.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the directory containing the active file, if it has one.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Snapshot the stored context buffers; buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                // Start of the interval reported as `retrieval_time`.
                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            // The active buffer is already part of the context: insert the
                            // cursor excerpt into its ranges, keeping them ordered by start
                            // and then by descending end.
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            // Move the active buffer to the end of the list.
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Shadows the tuple list with its request form; the outer
                        // `included_files` is used again below to resolve the response.
                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        // In agentic mode the excerpt fields are left empty; the content
                        // travels in `included_files` instead.
                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        // NOTE(review): this branch never adds the active buffer to
                        // `included_files`, which is what the diff response is resolved
                        // against further down — confirm that is intended.
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                // A prompt-building error is not propagated yet, so that it can first be
                // reported on the debug channel.
                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch: build and report the request, but never send it.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response =
                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                // Report the outcome (success or stringified error) before propagating it.
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(Uuid::from_str(&res.id)?);
                let Some(output_text) = text_from_response(res) else {
                    return Ok((None, usage))
                };

                // Parse the returned diff, resolving each path against the files that
                // were included in the request.
                let (edited_buffer_snapshot, edits) =
                    crate::udiff::parse_diff(&output_text, |path| {
                        included_files
                            .iter()
                            .find_map(|(_, buffer, probe_path, ranges)| {
                                if probe_path.as_ref() == path {
                                    Some((buffer, ranges.as_slice()))
                                } else {
                                    None
                                }
                            })
                    })
                    .await?;

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;

                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
            }
        });

        cx.spawn({
            async move |this, cx| {
                // Records usage / update-required state, then unwraps the payload.
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
976
977 async fn send_raw_llm_request(
978 client: Arc<Client>,
979 llm_token: LlmApiToken,
980 app_version: SemanticVersion,
981 request: open_ai::Request,
982 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
983 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
984 http_client::Url::parse(&predict_edits_url)?
985 } else {
986 client
987 .http_client()
988 .build_zed_llm_url("/predict_edits/raw", &[])?
989 };
990
991 Self::send_api_request(
992 |builder| {
993 let req = builder
994 .uri(url.as_ref())
995 .body(serde_json::to_string(&request)?.into());
996 Ok(req?)
997 },
998 client,
999 llm_token,
1000 app_version,
1001 )
1002 .await
1003 }
1004
1005 fn handle_api_response<T>(
1006 this: &WeakEntity<Self>,
1007 response: Result<(T, Option<EditPredictionUsage>)>,
1008 cx: &mut gpui::AsyncApp,
1009 ) -> Result<T> {
1010 match response {
1011 Ok((data, usage)) => {
1012 if let Some(usage) = usage {
1013 this.update(cx, |this, cx| {
1014 this.user_store.update(cx, |user_store, cx| {
1015 user_store.update_edit_prediction_usage(usage, cx);
1016 });
1017 })
1018 .ok();
1019 }
1020 Ok(data)
1021 }
1022 Err(err) => {
1023 if err.is::<ZedUpdateRequiredError>() {
1024 cx.update(|cx| {
1025 this.update(cx, |this, _cx| {
1026 this.update_required = true;
1027 })
1028 .ok();
1029
1030 let error_message: SharedString = err.to_string().into();
1031 show_app_notification(
1032 NotificationId::unique::<ZedUpdateRequiredError>(),
1033 cx,
1034 move |cx| {
1035 cx.new(|cx| {
1036 ErrorMessagePrompt::new(error_message.clone(), cx)
1037 .with_link_button("Update Zed", "https://zed.dev/releases")
1038 })
1039 },
1040 );
1041 })
1042 .ok();
1043 }
1044 Err(err)
1045 }
1046 }
1047 }
1048
1049 async fn send_api_request<Res>(
1050 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1051 client: Arc<Client>,
1052 llm_token: LlmApiToken,
1053 app_version: SemanticVersion,
1054 ) -> Result<(Res, Option<EditPredictionUsage>)>
1055 where
1056 Res: DeserializeOwned,
1057 {
1058 let http_client = client.http_client();
1059 let mut token = llm_token.acquire(&client).await?;
1060 let mut did_retry = false;
1061
1062 loop {
1063 let request_builder = http_client::Request::builder().method(Method::POST);
1064
1065 let request = build(
1066 request_builder
1067 .header("Content-Type", "application/json")
1068 .header("Authorization", format!("Bearer {}", token))
1069 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1070 )?;
1071
1072 let mut response = http_client.send(request).await?;
1073
1074 if let Some(minimum_required_version) = response
1075 .headers()
1076 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1077 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1078 {
1079 anyhow::ensure!(
1080 app_version >= minimum_required_version,
1081 ZedUpdateRequiredError {
1082 minimum_version: minimum_required_version
1083 }
1084 );
1085 }
1086
1087 if response.status().is_success() {
1088 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1089
1090 let mut body = Vec::new();
1091 response.body_mut().read_to_end(&mut body).await?;
1092 return Ok((serde_json::from_slice(&body)?, usage));
1093 } else if !did_retry
1094 && response
1095 .headers()
1096 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1097 .is_some()
1098 {
1099 did_retry = true;
1100 token = llm_token.refresh(&client).await?;
1101 } else {
1102 let mut body = String::new();
1103 response.body_mut().read_to_string(&mut body).await?;
1104 anyhow::bail!(
1105 "Request failed with status: {:?}\nBody: {}",
1106 response.status(),
1107 body
1108 );
1109 }
1110 }
1111 }
1112
    /// If more than this much time has passed since the previous refresh trigger, the
    /// user is considered to have been idle and context is refreshed without debouncing.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Delay applied before refreshing context when the user was not idle.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1115
1116 // Refresh the related excerpts when the user just beguns editing after
1117 // an idle period, and after they pause editing.
1118 fn refresh_context_if_needed(
1119 &mut self,
1120 project: &Entity<Project>,
1121 buffer: &Entity<language::Buffer>,
1122 cursor_position: language::Anchor,
1123 cx: &mut Context<Self>,
1124 ) {
1125 if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1126 return;
1127 }
1128
1129 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1130 return;
1131 };
1132
1133 let now = Instant::now();
1134 let was_idle = zeta_project
1135 .refresh_context_timestamp
1136 .map_or(true, |timestamp| {
1137 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1138 });
1139 zeta_project.refresh_context_timestamp = Some(now);
1140 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1141 let buffer = buffer.clone();
1142 let project = project.clone();
1143 async move |this, cx| {
1144 if was_idle {
1145 log::debug!("refetching edit prediction context after idle");
1146 } else {
1147 cx.background_executor()
1148 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1149 .await;
1150 log::debug!("refetching edit prediction context after pause");
1151 }
1152 this.update(cx, |this, cx| {
1153 let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1154
1155 if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1156 zeta_project.refresh_context_task = Some(task.log_err());
1157 };
1158 })
1159 .ok()
1160 }
1161 }));
1162 }
1163
    /// Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    /// and avoid spawning more than one concurrent task.
    ///
    /// The returned task builds a retrieval prompt from the excerpt around
    /// `cursor_position` and the project's recorded events, sends it to the model
    /// together with a single search tool definition, collects the search queries
    /// from the returned tool calls, runs them through
    /// `retrieval_search::run_retrieval_searches`, and stores the resulting
    /// excerpts in the project's `context`.
    ///
    /// Returns an already-completed `Ok(())` task when `project` has no entry in
    /// `self.projects`, when the context mode is not `ContextMode::Agentic`, or
    /// when no excerpt can be selected around the cursor. Returns an
    /// already-completed error task when the retrieval prompt cannot be built.
    ///
    /// NOTE(review): `refresh_context_task` is cleared and
    /// `ContextRetrievalFinished` is sent only in the final `this.update` call;
    /// the earlier `?` error returns inside the spawned task skip both.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Buffers without a backing file are reported under the path "untitled".
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        // Debug notifications are best-effort: send failures are ignored.
        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema and description of the search tool, computed once on first use.
        // The `unwrap` panics if the generated schema has no string `description`.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Single user message carrying the prompt, with the search tool as the
        // only tool offered to the model.
        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            // Only the last choice of the response is considered.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect the queries of every call to the search tool. Calls to any
            // other tool are logged and skipped; arguments that fail to parse
            // abort the whole refresh with an error.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            // The search result is not unwrapped here, so the bookkeeping below
            // also runs when the search itself failed.
            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                // The project may have been removed while the request was in flight.
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1360
1361 fn gather_nearby_diagnostics(
1362 cursor_offset: usize,
1363 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1364 snapshot: &BufferSnapshot,
1365 max_diagnostics_bytes: usize,
1366 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1367 // TODO: Could make this more efficient
1368 let mut diagnostic_groups = Vec::new();
1369 for (language_server_id, diagnostics) in diagnostic_sets {
1370 let mut groups = Vec::new();
1371 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1372 diagnostic_groups.extend(
1373 groups
1374 .into_iter()
1375 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1376 );
1377 }
1378
1379 // sort by proximity to cursor
1380 diagnostic_groups.sort_by_key(|group| {
1381 let range = &group.entries[group.primary_ix].range;
1382 if range.start >= cursor_offset {
1383 range.start - cursor_offset
1384 } else if cursor_offset >= range.end {
1385 cursor_offset - range.end
1386 } else {
1387 (cursor_offset - range.start).min(range.end - cursor_offset)
1388 }
1389 });
1390
1391 let mut results = Vec::new();
1392 let mut diagnostic_groups_truncated = false;
1393 let mut diagnostics_byte_count = 0;
1394 for group in diagnostic_groups {
1395 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1396 diagnostics_byte_count += raw_value.get().len();
1397 if diagnostics_byte_count > max_diagnostics_bytes {
1398 diagnostic_groups_truncated = true;
1399 break;
1400 }
1401 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1402 }
1403
1404 (results, diagnostic_groups_truncated)
1405 }
1406
1407 // TODO: Dedupe with similar code in request_prediction?
1408 pub fn cloud_request_for_zeta_cli(
1409 &mut self,
1410 project: &Entity<Project>,
1411 buffer: &Entity<Buffer>,
1412 position: language::Anchor,
1413 cx: &mut Context<Self>,
1414 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1415 let project_state = self.projects.get(&project.entity_id());
1416
1417 let index_state = project_state.map(|state| {
1418 state
1419 .syntax_index
1420 .read_with(cx, |index, _cx| index.state().clone())
1421 });
1422 let options = self.options.clone();
1423 let snapshot = buffer.read(cx).snapshot();
1424 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1425 return Task::ready(Err(anyhow!("No file path for excerpt")));
1426 };
1427 let worktree_snapshots = project
1428 .read(cx)
1429 .worktrees(cx)
1430 .map(|worktree| worktree.read(cx).snapshot())
1431 .collect::<Vec<_>>();
1432
1433 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1434 let mut path = f.worktree.read(cx).absolutize(&f.path);
1435 if path.pop() { Some(path) } else { None }
1436 });
1437
1438 cx.background_spawn(async move {
1439 let index_state = if let Some(index_state) = index_state {
1440 Some(index_state.lock_owned().await)
1441 } else {
1442 None
1443 };
1444
1445 let cursor_point = position.to_point(&snapshot);
1446
1447 let debug_info = true;
1448 EditPredictionContext::gather_context(
1449 cursor_point,
1450 &snapshot,
1451 parent_abs_path.as_deref(),
1452 match &options.context {
1453 ContextMode::Agentic(_) => {
1454 // TODO
1455 panic!("Llm mode not supported in zeta cli yet");
1456 }
1457 ContextMode::Syntax(edit_prediction_context_options) => {
1458 edit_prediction_context_options
1459 }
1460 },
1461 index_state.as_deref(),
1462 )
1463 .context("Failed to select excerpt")
1464 .map(|context| {
1465 make_syntax_context_cloud_request(
1466 excerpt_path.into(),
1467 context,
1468 // TODO pass everything
1469 Vec::new(),
1470 false,
1471 Vec::new(),
1472 false,
1473 None,
1474 debug_info,
1475 &worktree_snapshots,
1476 index_state.as_deref(),
1477 Some(options.max_prompt_bytes),
1478 options.prompt_format,
1479 )
1480 })
1481 })
1482 }
1483
1484 pub fn wait_for_initial_indexing(
1485 &mut self,
1486 project: &Entity<Project>,
1487 cx: &mut App,
1488 ) -> Task<Result<()>> {
1489 let zeta_project = self.get_or_init_zeta_project(project, cx);
1490 zeta_project
1491 .syntax_index
1492 .read(cx)
1493 .wait_for_initial_file_indexing(cx)
1494 }
1495}
1496
1497pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1498 let choice = res.choices.pop()?;
1499 let output_text = match choice.message {
1500 open_ai::RequestMessage::Assistant {
1501 content: Some(open_ai::MessageContent::Plain(content)),
1502 ..
1503 } => content,
1504 open_ai::RequestMessage::Assistant {
1505 content: Some(open_ai::MessageContent::Multipart(mut content)),
1506 ..
1507 } => {
1508 if content.is_empty() {
1509 log::error!("No output from Baseten completion response");
1510 return None;
1511 }
1512
1513 match content.remove(0) {
1514 open_ai::MessagePart::Text { text } => text,
1515 open_ai::MessagePart::Image { .. } => {
1516 log::error!("Expected text, got an image");
1517 return None;
1518 }
1519 }
1520 }
1521 _ => {
1522 log::error!("Invalid response message: {:?}", choice.message);
1523 return None;
1524 }
1525 };
1526 Some(output_text)
1527}
1528
/// Error reported when the running Zed version is older than the minimum
/// version required to keep using edit predictions.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version that is still accepted.
    minimum_version: SemanticVersion,
}
1536
1537fn make_syntax_context_cloud_request(
1538 excerpt_path: Arc<Path>,
1539 context: EditPredictionContext,
1540 events: Vec<predict_edits_v3::Event>,
1541 can_collect_data: bool,
1542 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1543 diagnostic_groups_truncated: bool,
1544 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1545 debug_info: bool,
1546 worktrees: &Vec<worktree::Snapshot>,
1547 index_state: Option<&SyntaxIndexState>,
1548 prompt_max_bytes: Option<usize>,
1549 prompt_format: PromptFormat,
1550) -> predict_edits_v3::PredictEditsRequest {
1551 let mut signatures = Vec::new();
1552 let mut declaration_to_signature_index = HashMap::default();
1553 let mut referenced_declarations = Vec::new();
1554
1555 for snippet in context.declarations {
1556 let project_entry_id = snippet.declaration.project_entry_id();
1557 let Some(path) = worktrees.iter().find_map(|worktree| {
1558 worktree.entry_for_id(project_entry_id).map(|entry| {
1559 let mut full_path = RelPathBuf::new();
1560 full_path.push(worktree.root_name());
1561 full_path.push(&entry.path);
1562 full_path
1563 })
1564 }) else {
1565 continue;
1566 };
1567
1568 let parent_index = index_state.and_then(|index_state| {
1569 snippet.declaration.parent().and_then(|parent| {
1570 add_signature(
1571 parent,
1572 &mut declaration_to_signature_index,
1573 &mut signatures,
1574 index_state,
1575 )
1576 })
1577 });
1578
1579 let (text, text_is_truncated) = snippet.declaration.item_text();
1580 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1581 path: path.as_std_path().into(),
1582 text: text.into(),
1583 range: snippet.declaration.item_line_range(),
1584 text_is_truncated,
1585 signature_range: snippet.declaration.signature_range_in_item_text(),
1586 parent_index,
1587 signature_score: snippet.score(DeclarationStyle::Signature),
1588 declaration_score: snippet.score(DeclarationStyle::Declaration),
1589 score_components: snippet.components,
1590 });
1591 }
1592
1593 let excerpt_parent = index_state.and_then(|index_state| {
1594 context
1595 .excerpt
1596 .parent_declarations
1597 .last()
1598 .and_then(|(parent, _)| {
1599 add_signature(
1600 *parent,
1601 &mut declaration_to_signature_index,
1602 &mut signatures,
1603 index_state,
1604 )
1605 })
1606 });
1607
1608 predict_edits_v3::PredictEditsRequest {
1609 excerpt_path,
1610 excerpt: context.excerpt_text.body,
1611 excerpt_line_range: context.excerpt.line_range,
1612 excerpt_range: context.excerpt.range,
1613 cursor_point: predict_edits_v3::Point {
1614 line: predict_edits_v3::Line(context.cursor_point.row),
1615 column: context.cursor_point.column,
1616 },
1617 referenced_declarations,
1618 included_files: vec![],
1619 signatures,
1620 excerpt_parent,
1621 events,
1622 can_collect_data,
1623 diagnostic_groups,
1624 diagnostic_groups_truncated,
1625 git_info,
1626 debug_info,
1627 prompt_max_bytes,
1628 prompt_format,
1629 }
1630}
1631
1632fn add_signature(
1633 declaration_id: DeclarationId,
1634 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1635 signatures: &mut Vec<Signature>,
1636 index: &SyntaxIndexState,
1637) -> Option<usize> {
1638 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1639 return Some(*signature_index);
1640 }
1641 let Some(parent_declaration) = index.declaration(declaration_id) else {
1642 log::error!("bug: missing parent declaration");
1643 return None;
1644 };
1645 let parent_index = parent_declaration.parent().and_then(|parent| {
1646 add_signature(parent, declaration_to_signature_index, signatures, index)
1647 });
1648 let (text, text_is_truncated) = parent_declaration.signature_text();
1649 let signature_index = signatures.len();
1650 signatures.push(Signature {
1651 text: text.into(),
1652 text_is_truncated,
1653 parent_index,
1654 range: parent_declaration.signature_line_range(),
1655 });
1656 declaration_to_signature_index.insert(declaration_id, signature_index);
1657 Some(signature_index)
1658}
1659
1660#[cfg(test)]
1661mod tests {
1662 use std::{path::Path, sync::Arc};
1663
1664 use client::UserStore;
1665 use clock::FakeSystemClock;
1666 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1667 use futures::{
1668 AsyncReadExt, StreamExt,
1669 channel::{mpsc, oneshot},
1670 };
1671 use gpui::{
1672 Entity, TestAppContext,
1673 http_client::{FakeHttpClient, Response},
1674 prelude::*,
1675 };
1676 use indoc::indoc;
1677 use language::OffsetRangeExt as _;
1678 use open_ai::Usage;
1679 use pretty_assertions::{assert_eq, assert_matches};
1680 use project::{FakeFs, Project};
1681 use serde_json::json;
1682 use settings::SettingsStore;
1683 use util::path;
1684 use uuid::Uuid;
1685
1686 use crate::{BufferEditPrediction, Zeta};
1687
    /// Checks which kind of prediction `current_prediction_for_buffer` reports:
    /// `Local` when the edit targets the queried buffer, and `Jump` when it
    /// targets a different file.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Row 1, column 3: the end of the line "How".
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Answer the search planning request with one search tool call
        // targeting `root/2.txt`.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // Queried through buffer1, an edit to 2.txt is reported as a jump.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Queried through the buffer it edits, the same prediction is local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1832
    /// Requests a prediction for a single file and checks that the diff returned
    /// by the fake model is turned into one edit inserting " are you?" at the
    /// cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Row 1, column 3: the end of the line "How".
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // Replacing "How" with "How are you?" is expected to become a single
        // insertion of the added suffix at the end of "How".
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1896
    /// Checks that an edit made to a registered buffer shows up as a diff in the
    /// prompt of the next prediction request, and that the prediction is still
    /// produced from the model's response.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer before editing it so the edit below is recorded.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt must contain the edit above rendered as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1970
1971 // Skipped until we start including diagnostics in prompt
1972 // #[gpui::test]
1973 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
1974 // let (zeta, mut req_rx) = init_test(cx);
1975 // let fs = FakeFs::new(cx.executor());
1976 // fs.insert_tree(
1977 // "/root",
1978 // json!({
1979 // "foo.md": "Hello!\nBye"
1980 // }),
1981 // )
1982 // .await;
1983 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1984
1985 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1986 // let diagnostic = lsp::Diagnostic {
1987 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1988 // severity: Some(lsp::DiagnosticSeverity::ERROR),
1989 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1990 // ..Default::default()
1991 // };
1992
1993 // project.update(cx, |project, cx| {
1994 // project.lsp_store().update(cx, |lsp_store, cx| {
1995 // // Create some diagnostics
1996 // lsp_store
1997 // .update_diagnostics(
1998 // LanguageServerId(0),
1999 // lsp::PublishDiagnosticsParams {
2000 // uri: path_to_buffer_uri.clone(),
2001 // diagnostics: vec![diagnostic],
2002 // version: None,
2003 // },
2004 // None,
2005 // language::DiagnosticSourceKind::Pushed,
2006 // &[],
2007 // cx,
2008 // )
2009 // .unwrap();
2010 // });
2011 // });
2012
2013 // let buffer = project
2014 // .update(cx, |project, cx| {
2015 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2016 // project.open_buffer(path, cx)
2017 // })
2018 // .await
2019 // .unwrap();
2020
2021 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2022 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2023
2024 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2025 // zeta.request_prediction(&project, &buffer, position, cx)
2026 // });
2027
2028 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2029
2030 // assert_eq!(request.diagnostic_groups.len(), 1);
2031 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2032 // .unwrap();
2033 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2034 // assert_eq!(
2035 // value,
2036 // json!({
2037 // "entries": [{
2038 // "range": {
2039 // "start": 8,
2040 // "end": 10
2041 // },
2042 // "diagnostic": {
2043 // "source": null,
2044 // "code": null,
2045 // "code_description": null,
2046 // "severity": 1,
2047 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2048 // "markdown": null,
2049 // "group_id": 0,
2050 // "is_primary": true,
2051 // "is_disk_based": false,
2052 // "is_unnecessary": false,
2053 // "source_kind": "Pushed",
2054 // "data": null,
2055 // "underline": true
2056 // }
2057 // }],
2058 // "primary_ix": 0
2059 // })
2060 // );
2061 // }
2062
2063 fn model_response(text: &str) -> open_ai::Response {
2064 open_ai::Response {
2065 id: Uuid::new_v4().to_string(),
2066 object: "response".into(),
2067 created: 0,
2068 model: "model".into(),
2069 choices: vec![open_ai::Choice {
2070 index: 0,
2071 message: open_ai::RequestMessage::Assistant {
2072 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2073 tool_calls: vec![],
2074 },
2075 finish_reason: None,
2076 }],
2077 usage: Usage {
2078 prompt_tokens: 0,
2079 completion_tokens: 0,
2080 total_tokens: 0,
2081 },
2082 }
2083 }
2084
2085 fn prompt_from_request(request: &open_ai::Request) -> &str {
2086 assert_eq!(request.messages.len(), 1);
2087 let open_ai::RequestMessage::User {
2088 content: open_ai::MessageContent::Plain(content),
2089 ..
2090 } = &request.messages[0]
2091 else {
2092 panic!(
2093 "Request does not have single user message of type Plain. {:#?}",
2094 request
2095 );
2096 };
2097 content
2098 }
2099
    /// Sets up the global state shared by the tests and returns the `Zeta`
    /// entity together with a receiver of intercepted model requests.
    ///
    /// The fake HTTP client answers `/client/llm_tokens` with a fixed token. For
    /// `/predict_edits/raw` it parses the request body, sends it through the
    /// returned receiver along with a oneshot sender, and replies with whatever
    /// response the test sends back. Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for its reply.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2154}