1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::Project;
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33
34use std::env;
35use std::ops::Range;
36use std::path::Path;
37use std::str::FromStr as _;
38use std::sync::{Arc, LazyLock};
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod merge_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49pub mod udiff;
50
51use crate::merge_excerpts::merge_excerpts;
52use crate::prediction::EditPrediction;
53pub use crate::prediction::EditPredictionId;
54pub use provider::ZetaEditPredictionProvider;
55
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering mode: agentic (model-driven) retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults used when context is gathered via the syntax index instead.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level options for [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};

// Set ZED_ZETA2_OLLAMA to any non-empty value to route requests to a local
// Ollama server instead of the Zed LLM service.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
// Model identifier sent with each request; override with ZED_ZETA2_MODEL.
static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
// Explicit prediction endpoint (ZED_PREDICT_EDITS_URL), falling back to the
// local Ollama endpoint when ZED_ZETA2_OLLAMA is set. `None` means the Zed
// LLM service URL is built at request time instead.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
109
/// Feature flag gating the zeta2 edit prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Deliberately not auto-enabled for staff.
    fn enabled_for_staff() -> bool {
        false
    }
}

/// Holds the app-global [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
124
/// Coordinates edit prediction across projects: tracks buffer edits,
/// retrieves context, and requests predictions from the LLM service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Cached token for the LLM service; refreshed on expiry notifications.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server reports this client is below the minimum version.
    update_required: bool,
    // When present, debug events are streamed to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}

/// Tunable parameters for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Edits closer together than this are coalesced into a single event.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval via search tool calls.
    Agentic(AgenticContextOptions),
    /// Syntax-index-based declaration retrieval.
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
156
157impl ContextMode {
158 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
159 match self {
160 ContextMode::Agentic(options) => &options.excerpt,
161 ContextMode::Syntax(options) => &options.excerpt,
162 }
163 }
164}
165
/// Events emitted over the debug channel (see [`Zeta::debug_info`]).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    // Time spent gathering context before the request was sent.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // The locally-built prompt, or the error message if building failed.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
206
/// Per-project prediction state.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    // Recent buffer-change events, oldest first; capped at MAX_EVENT_COUNT.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Retrieved context excerpts per buffer; `None` until the first retrieval.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    // Dropping this task cancels a pending (debounced) refresh.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    // When the last refresh was requested; used for idle detection.
    refresh_context_timestamp: Option<Instant>,
}

/// The prediction currently offered to the user, plus the id of the buffer
/// that triggered the request.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
223
impl CurrentEditPrediction {
    /// Whether `self` (a newly-arrived prediction) should replace
    /// `old_prediction` as the prediction shown to the user.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // A new prediction that no longer applies to the current buffer
        // contents never replaces anything.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer interpolates, it is stale: replace.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Single-edit case: only replace when the new edit extends the
            // old one at the same range.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
260
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

struct RegisteredBuffer {
    // Snapshot as of the last reported change; diffed against on edit.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// A user editing event recorded for inclusion in prediction requests.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
281
impl Event {
    /// Converts this event into the wire format, returning `None` when the
    /// event carries no information (no rename and an empty diff).
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // Only record the old path when the file was renamed.
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): `path == old_path` only holds when both are
                // `None` (a rename sets `old_path` to `Some`), so renames are
                // always emitted even with an empty diff.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }
}
319
320impl Zeta {
321 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
322 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
323 }
324
325 pub fn global(
326 client: &Arc<Client>,
327 user_store: &Entity<UserStore>,
328 cx: &mut App,
329 ) -> Entity<Self> {
330 cx.try_global::<ZetaGlobal>()
331 .map(|global| global.0.clone())
332 .unwrap_or_else(|| {
333 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
334 cx.set_global(ZetaGlobal(zeta.clone()));
335 zeta
336 })
337 }
338
    /// Creates a new `Zeta`, subscribing to LLM-token refresh notifications
    /// so the cached token stays valid.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    // Re-fetch the token whenever the listener fires.
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }
364
    /// Enables debug streaming and returns the receiving end of the channel.
    /// Replaces any previously-created debug channel.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// The currently-active options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the options. Note: syntax indices that already exist keep
    /// the `file_indexing_parallelism` they were created with.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
378
379 pub fn clear_history(&mut self) {
380 for zeta_project in self.projects.values_mut() {
381 zeta_project.events.clear();
382 }
383 }
384
385 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
386 self.projects
387 .get(&project.entity_id())
388 .map(|project| project.events.iter())
389 .into_iter()
390 .flatten()
391 }
392
393 pub fn context_for_project(
394 &self,
395 project: &Entity<Project>,
396 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
397 self.projects
398 .get(&project.entity_id())
399 .and_then(|project| {
400 Some(
401 project
402 .context
403 .as_ref()?
404 .iter()
405 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
406 )
407 })
408 .into_iter()
409 .flatten()
410 }
411
    /// Current edit prediction usage/quota, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state (including the syntax index) exists for
    /// `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` within `project`, recording its edits as
    /// prediction events.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
429
    /// Returns the state for `project`, creating it (and its syntax index)
    /// on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }
450
    /// Returns the registration for `buffer`, subscribing to its edit and
    /// release events on first registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record an event whenever the buffer is edited (and
                        // the project is still alive).
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
488
    /// Records a buffer-change event if `buffer` changed since its last
    /// registered snapshot, and returns the latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            // Swap the stored snapshot for the new one; the old snapshot
            // becomes the "before" side of the event.
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
516
    /// Appends `event` to the project's history, coalescing consecutive
    /// edits to the same buffer that fall within the grouping interval.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // Merge only if this event picks up exactly where the previous
            // one left off (same buffer, same version).
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
555
    /// Returns the current prediction as seen from `buffer`: `Local` when
    /// the prediction edits this buffer, `Jump` when this buffer requested a
    /// prediction that targets another buffer, `None` otherwise.
    fn current_prediction_for_buffer(
        &self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &App,
    ) -> Option<BufferEditPrediction<'_>> {
        let project_state = self.projects.get(&project.entity_id())?;

        let CurrentEditPrediction {
            requested_by_buffer_id,
            prediction,
        } = project_state.current_prediction.as_ref()?;

        if prediction.targets_buffer(buffer.read(cx)) {
            Some(BufferEditPrediction::Local { prediction })
        } else if *requested_by_buffer_id == buffer.entity_id() {
            Some(BufferEditPrediction::Jump { prediction })
        } else {
            None
        }
    }
577
    /// Consumes the current prediction and reports its acceptance to the
    /// server (or to `ZED_ACCEPT_PREDICTION_URL` when that is set).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Records usage from response headers / surfaces update-required.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
622
623 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
624 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
625 project_state.current_prediction.take();
626 };
627 }
628
629 pub fn refresh_prediction(
630 &mut self,
631 project: &Entity<Project>,
632 buffer: &Entity<Buffer>,
633 position: language::Anchor,
634 cx: &mut Context<Self>,
635 ) -> Task<Result<()>> {
636 let request_task = self.request_prediction(project, buffer, position, cx);
637 let buffer = buffer.clone();
638 let project = project.clone();
639
640 cx.spawn(async move |this, cx| {
641 if let Some(prediction) = request_task.await? {
642 this.update(cx, |this, cx| {
643 let project_state = this
644 .projects
645 .get_mut(&project.entity_id())
646 .context("Project not found")?;
647
648 let new_prediction = CurrentEditPrediction {
649 requested_by_buffer_id: buffer.entity_id(),
650 prediction: prediction,
651 };
652
653 if project_state
654 .current_prediction
655 .as_ref()
656 .is_none_or(|old_prediction| {
657 new_prediction.should_replace_prediction(&old_prediction, cx)
658 })
659 {
660 project_state.current_prediction = Some(new_prediction);
661 }
662 anyhow::Ok(())
663 })??;
664 }
665 Ok(())
666 })
667 }
668
    /// Builds and sends a prediction request for `active_buffer` at
    /// `position`, returning the parsed prediction — or `None` when no
    /// excerpt/context could be gathered or the model returned no text.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Serialize the recorded edit history into wire-format events.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the active file's parent directory, if any.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Previously-retrieved context excerpts, captured with their
        // buffers, snapshots, and paths. Buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the included files:
                        // insert its range sorted into the active buffer's
                        // entry and move that entry to the end of the list.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to merged line-based excerpts
                        // for the wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            // Excerpt content travels via `included_files`;
                            // these fields are unused in agentic mode.
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging, forward the request/prompt and keep a
                // channel to later carry the model response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch to exercise everything up to (but
                // not including) the network request.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response =
                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(output_text) = text_from_response(res) else {
                    return Ok((None, usage))
                };

                // The model answers with a unified diff; resolve its paths
                // against the included files to find the target buffer.
                let (edited_buffer_snapshot, edits) =
                    crate::udiff::parse_diff(&output_text, |path| {
                        included_files
                            .iter()
                            .find_map(|(_, buffer, probe_path, ranges)| {
                                if probe_path.as_ref() == path {
                                    Some((buffer, ranges.as_slice()))
                                } else {
                                    None
                                }
                            })
                    })
                    .await?;

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;

                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
995
    /// Sends an OpenAI-style request to the raw prediction endpoint (or to
    /// `PREDICT_EDITS_URL` when overridden) and returns the response plus
    /// any usage reported in the headers.
    async fn send_raw_llm_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: open_ai::Request,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await
    }
1023
    /// Applies the side effects of an API response: records usage on
    /// success; on `ZedUpdateRequiredError`, sets `update_required` and
    /// shows an update notification. Returns the payload or the error.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1067
    /// Sends an authenticated POST built by `build`, enforcing the server's
    /// minimum-version header and retrying once with a refreshed token when
    /// the LLM token has expired.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server can demand a newer client regardless of status code.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh once and retry the same request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1131
    /// Editing after this much inactivity counts as "starting fresh" and
    /// triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied while the user is actively typing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);

    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to the agentic mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the previous debounce task cancels it, restarting the
        // debounce window on every keystroke.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1182
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    /// Plans and runs context-retrieval searches for the buffer around
    /// `cursor_position`, storing the resulting excerpts on the project state.
    ///
    /// Returns a ready `Ok` task when the project is unregistered, the context
    /// mode is not agentic, or no excerpt can be selected at the cursor.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Select the excerpt around the cursor that seeds the search prompt.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the planning prompt from the excerpt plus recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Function-local static: the search tool's JSON schema and its
        // description are computed once and shared across all calls.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool; calls to
            // unrecognized tools are logged and skipped.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight task handle before storing the result.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1379
1380 fn gather_nearby_diagnostics(
1381 cursor_offset: usize,
1382 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1383 snapshot: &BufferSnapshot,
1384 max_diagnostics_bytes: usize,
1385 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1386 // TODO: Could make this more efficient
1387 let mut diagnostic_groups = Vec::new();
1388 for (language_server_id, diagnostics) in diagnostic_sets {
1389 let mut groups = Vec::new();
1390 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1391 diagnostic_groups.extend(
1392 groups
1393 .into_iter()
1394 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1395 );
1396 }
1397
1398 // sort by proximity to cursor
1399 diagnostic_groups.sort_by_key(|group| {
1400 let range = &group.entries[group.primary_ix].range;
1401 if range.start >= cursor_offset {
1402 range.start - cursor_offset
1403 } else if cursor_offset >= range.end {
1404 cursor_offset - range.end
1405 } else {
1406 (cursor_offset - range.start).min(range.end - cursor_offset)
1407 }
1408 });
1409
1410 let mut results = Vec::new();
1411 let mut diagnostic_groups_truncated = false;
1412 let mut diagnostics_byte_count = 0;
1413 for group in diagnostic_groups {
1414 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1415 diagnostics_byte_count += raw_value.get().len();
1416 if diagnostics_byte_count > max_diagnostics_bytes {
1417 diagnostic_groups_truncated = true;
1418 break;
1419 }
1420 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1421 }
1422
1423 (results, diagnostic_groups_truncated)
1424 }
1425
1426 // TODO: Dedupe with similar code in request_prediction?
1427 pub fn cloud_request_for_zeta_cli(
1428 &mut self,
1429 project: &Entity<Project>,
1430 buffer: &Entity<Buffer>,
1431 position: language::Anchor,
1432 cx: &mut Context<Self>,
1433 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1434 let project_state = self.projects.get(&project.entity_id());
1435
1436 let index_state = project_state.map(|state| {
1437 state
1438 .syntax_index
1439 .read_with(cx, |index, _cx| index.state().clone())
1440 });
1441 let options = self.options.clone();
1442 let snapshot = buffer.read(cx).snapshot();
1443 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1444 return Task::ready(Err(anyhow!("No file path for excerpt")));
1445 };
1446 let worktree_snapshots = project
1447 .read(cx)
1448 .worktrees(cx)
1449 .map(|worktree| worktree.read(cx).snapshot())
1450 .collect::<Vec<_>>();
1451
1452 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1453 let mut path = f.worktree.read(cx).absolutize(&f.path);
1454 if path.pop() { Some(path) } else { None }
1455 });
1456
1457 cx.background_spawn(async move {
1458 let index_state = if let Some(index_state) = index_state {
1459 Some(index_state.lock_owned().await)
1460 } else {
1461 None
1462 };
1463
1464 let cursor_point = position.to_point(&snapshot);
1465
1466 let debug_info = true;
1467 EditPredictionContext::gather_context(
1468 cursor_point,
1469 &snapshot,
1470 parent_abs_path.as_deref(),
1471 match &options.context {
1472 ContextMode::Agentic(_) => {
1473 // TODO
1474 panic!("Llm mode not supported in zeta cli yet");
1475 }
1476 ContextMode::Syntax(edit_prediction_context_options) => {
1477 edit_prediction_context_options
1478 }
1479 },
1480 index_state.as_deref(),
1481 )
1482 .context("Failed to select excerpt")
1483 .map(|context| {
1484 make_syntax_context_cloud_request(
1485 excerpt_path.into(),
1486 context,
1487 // TODO pass everything
1488 Vec::new(),
1489 false,
1490 Vec::new(),
1491 false,
1492 None,
1493 debug_info,
1494 &worktree_snapshots,
1495 index_state.as_deref(),
1496 Some(options.max_prompt_bytes),
1497 options.prompt_format,
1498 )
1499 })
1500 })
1501 }
1502
1503 pub fn wait_for_initial_indexing(
1504 &mut self,
1505 project: &Entity<Project>,
1506 cx: &mut App,
1507 ) -> Task<Result<()>> {
1508 let zeta_project = self.get_or_init_zeta_project(project, cx);
1509 zeta_project
1510 .syntax_index
1511 .read(cx)
1512 .wait_for_initial_file_indexing(cx)
1513 }
1514}
1515
1516pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1517 let choice = res.choices.pop()?;
1518 let output_text = match choice.message {
1519 open_ai::RequestMessage::Assistant {
1520 content: Some(open_ai::MessageContent::Plain(content)),
1521 ..
1522 } => content,
1523 open_ai::RequestMessage::Assistant {
1524 content: Some(open_ai::MessageContent::Multipart(mut content)),
1525 ..
1526 } => {
1527 if content.is_empty() {
1528 log::error!("No output from Baseten completion response");
1529 return None;
1530 }
1531
1532 match content.remove(0) {
1533 open_ai::MessagePart::Text { text } => text,
1534 open_ai::MessagePart::Image { .. } => {
1535 log::error!("Expected text, got an image");
1536 return None;
1537 }
1538 }
1539 }
1540 _ => {
1541 log::error!("Invalid response message: {:?}", choice.message);
1542 return None;
1543 }
1544 };
1545 Some(output_text)
1546}
1547
/// Error produced when the cloud service reports (via response header) a
/// minimum required Zed version newer than the running one; surfaced by
/// `send_api_request`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Minimum version the server will accept, parsed from the response header.
    minimum_version: SemanticVersion,
}
1555
1556fn make_syntax_context_cloud_request(
1557 excerpt_path: Arc<Path>,
1558 context: EditPredictionContext,
1559 events: Vec<predict_edits_v3::Event>,
1560 can_collect_data: bool,
1561 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1562 diagnostic_groups_truncated: bool,
1563 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1564 debug_info: bool,
1565 worktrees: &Vec<worktree::Snapshot>,
1566 index_state: Option<&SyntaxIndexState>,
1567 prompt_max_bytes: Option<usize>,
1568 prompt_format: PromptFormat,
1569) -> predict_edits_v3::PredictEditsRequest {
1570 let mut signatures = Vec::new();
1571 let mut declaration_to_signature_index = HashMap::default();
1572 let mut referenced_declarations = Vec::new();
1573
1574 for snippet in context.declarations {
1575 let project_entry_id = snippet.declaration.project_entry_id();
1576 let Some(path) = worktrees.iter().find_map(|worktree| {
1577 worktree.entry_for_id(project_entry_id).map(|entry| {
1578 let mut full_path = RelPathBuf::new();
1579 full_path.push(worktree.root_name());
1580 full_path.push(&entry.path);
1581 full_path
1582 })
1583 }) else {
1584 continue;
1585 };
1586
1587 let parent_index = index_state.and_then(|index_state| {
1588 snippet.declaration.parent().and_then(|parent| {
1589 add_signature(
1590 parent,
1591 &mut declaration_to_signature_index,
1592 &mut signatures,
1593 index_state,
1594 )
1595 })
1596 });
1597
1598 let (text, text_is_truncated) = snippet.declaration.item_text();
1599 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1600 path: path.as_std_path().into(),
1601 text: text.into(),
1602 range: snippet.declaration.item_line_range(),
1603 text_is_truncated,
1604 signature_range: snippet.declaration.signature_range_in_item_text(),
1605 parent_index,
1606 signature_score: snippet.score(DeclarationStyle::Signature),
1607 declaration_score: snippet.score(DeclarationStyle::Declaration),
1608 score_components: snippet.components,
1609 });
1610 }
1611
1612 let excerpt_parent = index_state.and_then(|index_state| {
1613 context
1614 .excerpt
1615 .parent_declarations
1616 .last()
1617 .and_then(|(parent, _)| {
1618 add_signature(
1619 *parent,
1620 &mut declaration_to_signature_index,
1621 &mut signatures,
1622 index_state,
1623 )
1624 })
1625 });
1626
1627 predict_edits_v3::PredictEditsRequest {
1628 excerpt_path,
1629 excerpt: context.excerpt_text.body,
1630 excerpt_line_range: context.excerpt.line_range,
1631 excerpt_range: context.excerpt.range,
1632 cursor_point: predict_edits_v3::Point {
1633 line: predict_edits_v3::Line(context.cursor_point.row),
1634 column: context.cursor_point.column,
1635 },
1636 referenced_declarations,
1637 included_files: vec![],
1638 signatures,
1639 excerpt_parent,
1640 events,
1641 can_collect_data,
1642 diagnostic_groups,
1643 diagnostic_groups_truncated,
1644 git_info,
1645 debug_info,
1646 prompt_max_bytes,
1647 prompt_format,
1648 }
1649}
1650
1651fn add_signature(
1652 declaration_id: DeclarationId,
1653 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1654 signatures: &mut Vec<Signature>,
1655 index: &SyntaxIndexState,
1656) -> Option<usize> {
1657 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1658 return Some(*signature_index);
1659 }
1660 let Some(parent_declaration) = index.declaration(declaration_id) else {
1661 log::error!("bug: missing parent declaration");
1662 return None;
1663 };
1664 let parent_index = parent_declaration.parent().and_then(|parent| {
1665 add_signature(parent, declaration_to_signature_index, signatures, index)
1666 });
1667 let (text, text_is_truncated) = parent_declaration.signature_text();
1668 let signature_index = signatures.len();
1669 signatures.push(Signature {
1670 text: text.into(),
1671 text_is_truncated,
1672 parent_index,
1673 range: parent_declaration.signature_line_range(),
1674 });
1675 declaration_to_signature_index.insert(declaration_id, signature_index);
1676 Some(signature_index)
1677}
1678
1679#[cfg(test)]
1680mod tests {
1681 use std::{path::Path, sync::Arc};
1682
1683 use client::UserStore;
1684 use clock::FakeSystemClock;
1685 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1686 use futures::{
1687 AsyncReadExt, StreamExt,
1688 channel::{mpsc, oneshot},
1689 };
1690 use gpui::{
1691 Entity, TestAppContext,
1692 http_client::{FakeHttpClient, Response},
1693 prelude::*,
1694 };
1695 use indoc::indoc;
1696 use language::OffsetRangeExt as _;
1697 use open_ai::Usage;
1698 use pretty_assertions::{assert_eq, assert_matches};
1699 use project::{FakeFs, Project};
1700 use serde_json::json;
1701 use settings::SettingsStore;
1702 use util::path;
1703 use uuid::Uuid;
1704
1705 use crate::{BufferEditPrediction, Zeta};
1706
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        // Exercises how predictions surface per buffer: an edit targeting the
        // current file reads as `Local`, one targeting another file reads as
        // `Jump`, and becomes `Local` once that file's buffer is opened. A
        // context-refresh round-trip runs in between.
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff editing the file the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Reply to the planning request with one search tool call targeting
        // the other file.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                Hola!
                -Como
                +Como estas?
                Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once the target buffer is open, the same prediction reads as Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1851
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        // A single prediction round-trip: request a prediction, answer with a
        // one-hunk diff, then verify the resulting edit's anchor and text.
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff replaces "How" with "How are you?", which resolves to a
        // single insertion of " are you?" at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1915
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        // Verifies that an edit made to a registered buffer shows up in the
        // next prediction request's prompt as a udiff event.
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Insert "How" on the empty second line; this becomes the event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The edit above should appear in the prompt as a diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                Hello!
                -
                +How
                Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                Hello!
                -How
                +How are you?
                Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1989
1990 // Skipped until we start including diagnostics in prompt
1991 // #[gpui::test]
1992 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
1993 // let (zeta, mut req_rx) = init_test(cx);
1994 // let fs = FakeFs::new(cx.executor());
1995 // fs.insert_tree(
1996 // "/root",
1997 // json!({
1998 // "foo.md": "Hello!\nBye"
1999 // }),
2000 // )
2001 // .await;
2002 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2003
2004 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2005 // let diagnostic = lsp::Diagnostic {
2006 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2007 // severity: Some(lsp::DiagnosticSeverity::ERROR),
2008 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2009 // ..Default::default()
2010 // };
2011
2012 // project.update(cx, |project, cx| {
2013 // project.lsp_store().update(cx, |lsp_store, cx| {
2014 // // Create some diagnostics
2015 // lsp_store
2016 // .update_diagnostics(
2017 // LanguageServerId(0),
2018 // lsp::PublishDiagnosticsParams {
2019 // uri: path_to_buffer_uri.clone(),
2020 // diagnostics: vec![diagnostic],
2021 // version: None,
2022 // },
2023 // None,
2024 // language::DiagnosticSourceKind::Pushed,
2025 // &[],
2026 // cx,
2027 // )
2028 // .unwrap();
2029 // });
2030 // });
2031
2032 // let buffer = project
2033 // .update(cx, |project, cx| {
2034 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2035 // project.open_buffer(path, cx)
2036 // })
2037 // .await
2038 // .unwrap();
2039
2040 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2041 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2042
2043 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2044 // zeta.request_prediction(&project, &buffer, position, cx)
2045 // });
2046
2047 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2048
2049 // assert_eq!(request.diagnostic_groups.len(), 1);
2050 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2051 // .unwrap();
2052 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2053 // assert_eq!(
2054 // value,
2055 // json!({
2056 // "entries": [{
2057 // "range": {
2058 // "start": 8,
2059 // "end": 10
2060 // },
2061 // "diagnostic": {
2062 // "source": null,
2063 // "code": null,
2064 // "code_description": null,
2065 // "severity": 1,
2066 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2067 // "markdown": null,
2068 // "group_id": 0,
2069 // "is_primary": true,
2070 // "is_disk_based": false,
2071 // "is_unnecessary": false,
2072 // "source_kind": "Pushed",
2073 // "data": null,
2074 // "underline": true
2075 // }
2076 // }],
2077 // "primary_ix": 0
2078 // })
2079 // );
2080 // }
2081
2082 fn model_response(text: &str) -> open_ai::Response {
2083 open_ai::Response {
2084 id: Uuid::new_v4().to_string(),
2085 object: "response".into(),
2086 created: 0,
2087 model: "model".into(),
2088 choices: vec![open_ai::Choice {
2089 index: 0,
2090 message: open_ai::RequestMessage::Assistant {
2091 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2092 tool_calls: vec![],
2093 },
2094 finish_reason: None,
2095 }],
2096 usage: Usage {
2097 prompt_tokens: 0,
2098 completion_tokens: 0,
2099 total_tokens: 0,
2100 },
2101 }
2102 }
2103
2104 fn prompt_from_request(request: &open_ai::Request) -> &str {
2105 assert_eq!(request.messages.len(), 1);
2106 let open_ai::RequestMessage::User {
2107 content: open_ai::MessageContent::Plain(content),
2108 ..
2109 } = &request.messages[0]
2110 else {
2111 panic!(
2112 "Request does not have single user message of type Plain. {:#?}",
2113 request
2114 );
2115 };
2116 content
2117 }
2118
    /// Sets up a `Zeta` entity backed by a fake HTTP client.
    ///
    /// The fake server answers `/client/llm_tokens` with a static token, and
    /// forwards each `/predict_edits/raw` request body — paired with a oneshot
    /// responder — over the returned channel so tests can inspect requests and
    /// reply with canned `open_ai::Response` values. Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block this
                                // fake response on the test's reply.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2173}