1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::Project;
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33use uuid::Uuid;
34
35use std::ops::Range;
36use std::path::Path;
37use std::str::FromStr as _;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod merge_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49pub mod udiff;
50
51use crate::merge_excerpts::merge_excerpts;
52use crate::prediction::EditPrediction;
53pub use crate::prediction::EditPredictionId;
54pub use provider::ZetaEditPredictionProvider;
55
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt selected around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for half of the excerpt bytes before the cursor, half after.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: agentic retrieval with the default excerpt sizing.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for the syntax-index-driven context mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // 0 disables declaration retrieval by default.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default top-level Zeta options used when a `Zeta` instance is created.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    // Buffer edits closer together than this are coalesced into one event.
    buffer_change_grouping_interval: Duration::from_secs(1),
};
90
/// Feature flag gating the zeta2 edit-prediction implementation.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Explicitly opt staff out of automatic enablement; the flag must be
    // turned on per-user.
    fn enabled_for_staff() -> bool {
        false
    }
}
100
/// Newtype wrapper storing the singleton `Zeta` entity in GPUI's global state.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
105
/// Coordinates zeta2 edit predictions across all registered projects.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate LLM requests; refreshed on expiry.
    llm_token: LlmApiToken,
    /// Keeps `llm_token` fresh when a refresh is requested.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when the server reported this Zed version is below the minimum.
    update_required: bool,
    /// When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
116
/// Tunable options controlling prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered (agentic vs. syntax index).
    pub context: ContextMode,
    /// Upper bound on the rendered prompt size, in bytes.
    pub max_prompt_bytes: usize,
    /// Upper bound on diagnostic data included in a request, in bytes.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism used by the syntax index when indexing files.
    pub file_indexing_parallelism: usize,
    /// Edits closer together than this are coalesced into a single event.
    pub buffer_change_grouping_interval: Duration,
}
126
/// Strategy used to gather context for an edit prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Use previously retrieved excerpts (refreshed asynchronously).
    Agentic(AgenticContextOptions),
    /// Gather context synchronously from the syntax index.
    Syntax(EditPredictionContextOptions),
}
132
/// Options for the agentic context mode.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    /// Sizing of the excerpt selected around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
}
137
138impl ContextMode {
139 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
140 match self {
141 ContextMode::Agentic(options) => &options.excerpt,
142 ContextMode::Syntax(options) => &options.excerpt,
143 }
144 }
145}
146
/// Debug events emitted over the channel returned by `Zeta::debug_info`.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
155
/// Debug payload emitted when context retrieval begins for a project.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to generate search queries.
    pub search_prompt: String,
}
162
/// Debug payload shared by context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}
168
/// Debug payload emitted when an edit prediction request is sent.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context before the request was sent.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally rendered prompt, or the error produced while building it.
    pub local_prompt: Result<String, String>,
    /// Receives the model response (or error) plus the request round-trip time.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}
178
/// Debug payload emitted when search queries are generated.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Generated regex search per glob pattern.
    pub regex_by_glob: HashMap<String, String>,
}
185
/// Alias for the server-side debug info attached to prediction responses.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
187
/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The most recent prediction, if any, awaiting acceptance/discard.
    current_prediction: Option<CurrentEditPrediction>,
    /// Retrieved context ranges per buffer (agentic mode); `None` until fetched.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context refresh, kept alive so it runs to completion.
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When the last context refresh was requested; used for idle detection.
    refresh_context_timestamp: Option<Instant>,
}
198
/// A prediction together with the buffer that triggered the request.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer where the request originated (may differ from
    /// the buffer the prediction edits).
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
204
205impl CurrentEditPrediction {
206 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
207 let Some(new_edits) = self
208 .prediction
209 .interpolate(&self.prediction.buffer.read(cx))
210 else {
211 return false;
212 };
213
214 if self.prediction.buffer != old_prediction.prediction.buffer {
215 return true;
216 }
217
218 let Some(old_edits) = old_prediction
219 .prediction
220 .interpolate(&old_prediction.prediction.buffer.read(cx))
221 else {
222 return true;
223 };
224
225 // This reduces the occurrence of UI thrash from replacing edits
226 //
227 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
228 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
229 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
230 && old_edits.len() == 1
231 && new_edits.len() == 1
232 {
233 let (old_range, old_text) = &old_edits[0];
234 let (new_range, new_text) = &new_edits[0];
235 new_range == old_range && new_text.starts_with(old_text.as_ref())
236 } else {
237 true
238 }
239 }
240}
241
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; this buffer requested it.
    Jump { prediction: &'a EditPrediction },
}
248
/// Per-buffer state: the last observed snapshot plus the edit/release
/// subscriptions that keep it up to date.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
253
/// A tracked user action, used as history context in prediction requests.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change happened; used to coalesce rapid edits.
        timestamp: Instant,
    },
}
262
263impl Event {
264 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
265 match self {
266 Event::BufferChange {
267 old_snapshot,
268 new_snapshot,
269 ..
270 } => {
271 let path = new_snapshot.file().map(|f| f.full_path(cx));
272
273 let old_path = old_snapshot.file().and_then(|f| {
274 let old_path = f.full_path(cx);
275 if Some(&old_path) != path.as_ref() {
276 Some(old_path)
277 } else {
278 None
279 }
280 });
281
282 // TODO [zeta2] move to bg?
283 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
284
285 if path == old_path && diff.is_empty() {
286 None
287 } else {
288 Some(predict_edits_v3::Event::BufferChange {
289 old_path,
290 path,
291 diff,
292 //todo: Actually detect if this edit was predicted or not
293 predicted: false,
294 })
295 }
296 }
297 }
298 }
299}
300
301impl Zeta {
302 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
303 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
304 }
305
306 pub fn global(
307 client: &Arc<Client>,
308 user_store: &Entity<UserStore>,
309 cx: &mut App,
310 ) -> Entity<Self> {
311 cx.try_global::<ZetaGlobal>()
312 .map(|global| global.0.clone())
313 .unwrap_or_else(|| {
314 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
315 cx.set_global(ZetaGlobal(zeta.clone()));
316 zeta
317 })
318 }
319
320 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
321 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
322
323 Self {
324 projects: HashMap::default(),
325 client,
326 user_store,
327 options: DEFAULT_OPTIONS,
328 llm_token: LlmApiToken::default(),
329 _llm_token_subscription: cx.subscribe(
330 &refresh_llm_token_listener,
331 |this, _listener, _event, cx| {
332 let client = this.client.clone();
333 let llm_token = this.llm_token.clone();
334 cx.spawn(async move |_this, _cx| {
335 llm_token.refresh(&client).await?;
336 anyhow::Ok(())
337 })
338 .detach_and_log_err(cx);
339 },
340 ),
341 update_required: false,
342 debug_tx: None,
343 }
344 }
345
346 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
347 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
348 self.debug_tx = Some(debug_watch_tx);
349 debug_watch_rx
350 }
351
    /// Returns the current options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
355
    /// Replaces the current options. Takes effect for subsequent requests.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
359
360 pub fn clear_history(&mut self) {
361 for zeta_project in self.projects.values_mut() {
362 zeta_project.events.clear();
363 }
364 }
365
366 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
367 self.projects
368 .get(&project.entity_id())
369 .map(|project| project.events.iter())
370 .into_iter()
371 .flatten()
372 }
373
374 pub fn context_for_project(
375 &self,
376 project: &Entity<Project>,
377 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
378 self.projects
379 .get(&project.entity_id())
380 .and_then(|project| {
381 Some(
382 project
383 .context
384 .as_ref()?
385 .iter()
386 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
387 )
388 })
389 .into_iter()
390 .flatten()
391 }
392
    /// Returns the user's current edit-prediction usage, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
396
    /// Ensures per-project state (including the syntax index) exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
400
    /// Starts tracking `buffer` within `project`, subscribing to its edits so
    /// changes are recorded as history events.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
410
    /// Returns the state for `project`, creating it (and spawning its syntax
    /// index) on first access.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| {
                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                }),
                events: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
            })
    }
431
    /// Returns the `RegisteredBuffer` for `buffer`, creating it on first use.
    ///
    /// On creation, takes a baseline snapshot and installs two subscriptions:
    /// one that records edits as history events, and one that removes the
    /// registration when the buffer entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold a weak handle so this subscription doesn't
                            // keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
469
    /// Records any change to `buffer` since its last registered snapshot as a
    /// `BufferChange` event, and returns the latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Only record an event if the buffer actually changed since the last
        // snapshot we stored.
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
497
    /// Appends `event` to the project's history, coalescing consecutive edits
    /// to the same buffer that occur within `buffer_change_grouping_interval`,
    /// and bounding the history length at `MAX_EVENT_COUNT`.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // Merge only if the new event starts exactly where the previous
            // one ended (same buffer, same version) and is recent enough.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
536
537 fn current_prediction_for_buffer(
538 &self,
539 buffer: &Entity<Buffer>,
540 project: &Entity<Project>,
541 cx: &App,
542 ) -> Option<BufferEditPrediction<'_>> {
543 let project_state = self.projects.get(&project.entity_id())?;
544
545 let CurrentEditPrediction {
546 requested_by_buffer_id,
547 prediction,
548 } = project_state.current_prediction.as_ref()?;
549
550 if prediction.targets_buffer(buffer.read(cx)) {
551 Some(BufferEditPrediction::Local { prediction })
552 } else if *requested_by_buffer_id == buffer.entity_id() {
553 Some(BufferEditPrediction::Jump { prediction })
554 } else {
555 None
556 }
557 }
558
    /// Accepts the current prediction for `project` (removing it from state)
    /// and reports the acceptance to the server in the background.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // Allow overriding the endpoint via env var (for development).
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Updates usage from response headers / surfaces update-required errors.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
600
601 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
602 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
603 project_state.current_prediction.take();
604 };
605 }
606
607 pub fn refresh_prediction(
608 &mut self,
609 project: &Entity<Project>,
610 buffer: &Entity<Buffer>,
611 position: language::Anchor,
612 cx: &mut Context<Self>,
613 ) -> Task<Result<()>> {
614 let request_task = self.request_prediction(project, buffer, position, cx);
615 let buffer = buffer.clone();
616 let project = project.clone();
617
618 cx.spawn(async move |this, cx| {
619 if let Some(prediction) = request_task.await? {
620 this.update(cx, |this, cx| {
621 let project_state = this
622 .projects
623 .get_mut(&project.entity_id())
624 .context("Project not found")?;
625
626 let new_prediction = CurrentEditPrediction {
627 requested_by_buffer_id: buffer.entity_id(),
628 prediction: prediction,
629 };
630
631 if project_state
632 .current_prediction
633 .as_ref()
634 .is_none_or(|old_prediction| {
635 new_prediction.should_replace_prediction(&old_prediction, cx)
636 })
637 {
638 project_state.current_prediction = Some(new_prediction);
639 }
640 anyhow::Ok(())
641 })??;
642 }
643 Ok(())
644 })
645 }
646
    /// Requests a single edit prediction for `active_buffer` at `position`.
    ///
    /// Gathers context according to the configured `ContextMode`, builds and
    /// sends the prompt to the LLM endpoint, parses the returned unified diff,
    /// and resolves it against the buffers included in the request. Returns
    /// `Ok(None)` when no excerpt/context could be gathered or the response
    /// contained no usable output.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Shared handle to the syntax index state; locked on the background
        // thread below.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert tracked history events into their wire representation.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the buffer's parent directory, when available;
        // used by syntax-mode context gathering.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Snapshot the previously retrieved context (agentic mode) as
        // (buffer, snapshot, path, ranges) tuples.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        // Select the excerpt around the cursor; bail out with
                        // no prediction if none can be selected.
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the included files,
                        // keeping the active buffer's entry last.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            // Insert sorted by range start.
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to merged line-based excerpts.
                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        // Gather context synchronously from the syntax index.
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging is enabled, announce the request and set up a
                // channel for forwarding the response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Development escape hatch: build everything but skip the
                // network request.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    // Model id override for development; default is the
                    // deployed model.
                    model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response =
                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                // Forward the raw response (or error) to the debug channel.
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (mut res, usage) = response?;

                let request_id = EditPredictionId(Uuid::from_str(&res.id)?);

                let Some(choice) = res.choices.pop() else {
                    return Ok((None, usage));
                };

                // Extract the plain-text completion from the response message.
                let output_text = match choice.message {
                    open_ai::RequestMessage::Assistant {
                        content: Some(open_ai::MessageContent::Plain(content)),
                        ..
                    } => content,
                    open_ai::RequestMessage::Assistant {
                        content: Some(open_ai::MessageContent::Multipart(mut content)),
                        ..
                    } => {
                        if content.is_empty() {
                            log::error!("No output from Baseten completion response");
                            return Ok((None, usage));
                        }

                        match content.remove(0) {
                            open_ai::MessagePart::Text { text } => text,
                            open_ai::MessagePart::Image { .. } => {
                                log::error!("Expected text, got an image");
                                return Ok((None, usage));
                            }
                        }
                    }
                    _ => {
                        log::error!("Invalid response message: {:?}", choice.message);
                        return Ok((None, usage));
                    }
                };

                // Parse the model's unified diff, resolving paths against the
                // files that were included in the request.
                let (edited_buffer_snapshot, edits) =
                    crate::udiff::parse_diff(&output_text, |path| {
                        included_files
                            .iter()
                            .find_map(|(_, buffer, probe_path, ranges)| {
                                if probe_path.as_ref() == path {
                                    Some((buffer, ranges.as_slice()))
                                } else {
                                    None
                                }
                            })
                    })
                    .await?;

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;

                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
1003
    /// Sends an OpenAI-style completion request to the predict-edits raw
    /// endpoint (or `ZED_PREDICT_EDITS_URL` when set) and returns the parsed
    /// response plus any usage reported in the headers.
    async fn send_raw_llm_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: open_ai::Request,
    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
        // Allow overriding the endpoint via env var (for development).
        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
            http_client::Url::parse(&predict_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/raw", &[])?
        };

        Self::send_api_request(
            |builder| {
                let req = builder
                    .uri(url.as_ref())
                    .body(serde_json::to_string(&request)?.into());
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
        )
        .await
    }
1031
    /// Post-processes an API response: on success, records any usage data on
    /// the user store; on `ZedUpdateRequiredError`, flags `update_required`
    /// and shows an app notification prompting the user to update.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                // Propagate the original error to the caller either way.
                Err(err)
            }
        }
    }
1075
    /// Sends an authenticated POST request built by `build` and deserializes
    /// the JSON response.
    ///
    /// Enforces the server's minimum-version header (failing with
    /// `ZedUpdateRequiredError`), reads usage from response headers, and
    /// retries exactly once with a refreshed token when the server reports the
    /// LLM token has expired.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // If the server advertises a minimum required version, fail fast
            // when this build is older.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh once and retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1139
    // If this much time passes without edits, the next edit is treated as the
    // start of a new editing session and triggers an immediate context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    // Debounce applied while the user is actively typing, so context is
    // refreshed only after they pause.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1142
    // Refresh the related excerpts when the user begins editing after
    // an idle period, and after they pause editing.
    //
    // Only applies in agentic context mode; replaces any pending debounce task,
    // so rapid successive edits collapse into a single refresh.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Idle means: no recorded refresh timestamp yet, or the last one is
        // older than the idle window.
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Overwriting the previous debounce task cancels it, restarting the
        // debounce window on every edit.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // First edit after idling: refresh right away.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-session edit: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the handle so the refresh runs to completion and so
                    // only one refresh task exists per project.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1190
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Flow: build a search-planning prompt from the cursor excerpt and recent
    // events, ask the model to plan searches via a single tool, run the
    // resulting regex searches, then store the excerpts on the project state.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op for projects that haven't been registered.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval only runs in agentic mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Pick the excerpt around the cursor that will be shown to the model.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the excerpt plus the project's
        // recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();

        let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;

        let request = open_ai::Request {
            // Model id can be overridden via env var for experimentation.
            model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            // The model plans retrieval by calling this single search tool.
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description.clone()),
                    parameters: Some(tool_schema.clone()),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;

            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Combine all requested queries into one alternation regex per glob.
            let mut regex_by_glob: HashMap<String, String> = HashMap::default();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    // Tolerate unexpected tool calls rather than failing the refresh.
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                for query in input.queries {
                    let regex = regex_by_glob.entry(query.glob).or_default();
                    if !regex.is_empty() {
                        regex.push('|');
                    }
                    regex.push_str(&query.regex);
                }
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            regex_by_glob: regex_by_glob.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {regex_by_glob:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Drop the in-flight task handle now that retrieval finished.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        // Store the retrieved excerpts for use by predictions.
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1369
1370 fn gather_nearby_diagnostics(
1371 cursor_offset: usize,
1372 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1373 snapshot: &BufferSnapshot,
1374 max_diagnostics_bytes: usize,
1375 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1376 // TODO: Could make this more efficient
1377 let mut diagnostic_groups = Vec::new();
1378 for (language_server_id, diagnostics) in diagnostic_sets {
1379 let mut groups = Vec::new();
1380 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1381 diagnostic_groups.extend(
1382 groups
1383 .into_iter()
1384 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1385 );
1386 }
1387
1388 // sort by proximity to cursor
1389 diagnostic_groups.sort_by_key(|group| {
1390 let range = &group.entries[group.primary_ix].range;
1391 if range.start >= cursor_offset {
1392 range.start - cursor_offset
1393 } else if cursor_offset >= range.end {
1394 cursor_offset - range.end
1395 } else {
1396 (cursor_offset - range.start).min(range.end - cursor_offset)
1397 }
1398 });
1399
1400 let mut results = Vec::new();
1401 let mut diagnostic_groups_truncated = false;
1402 let mut diagnostics_byte_count = 0;
1403 for group in diagnostic_groups {
1404 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1405 diagnostics_byte_count += raw_value.get().len();
1406 if diagnostics_byte_count > max_diagnostics_bytes {
1407 diagnostic_groups_truncated = true;
1408 break;
1409 }
1410 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1411 }
1412
1413 (results, diagnostic_groups_truncated)
1414 }
1415
1416 // TODO: Dedupe with similar code in request_prediction?
1417 pub fn cloud_request_for_zeta_cli(
1418 &mut self,
1419 project: &Entity<Project>,
1420 buffer: &Entity<Buffer>,
1421 position: language::Anchor,
1422 cx: &mut Context<Self>,
1423 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1424 let project_state = self.projects.get(&project.entity_id());
1425
1426 let index_state = project_state.map(|state| {
1427 state
1428 .syntax_index
1429 .read_with(cx, |index, _cx| index.state().clone())
1430 });
1431 let options = self.options.clone();
1432 let snapshot = buffer.read(cx).snapshot();
1433 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1434 return Task::ready(Err(anyhow!("No file path for excerpt")));
1435 };
1436 let worktree_snapshots = project
1437 .read(cx)
1438 .worktrees(cx)
1439 .map(|worktree| worktree.read(cx).snapshot())
1440 .collect::<Vec<_>>();
1441
1442 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1443 let mut path = f.worktree.read(cx).absolutize(&f.path);
1444 if path.pop() { Some(path) } else { None }
1445 });
1446
1447 cx.background_spawn(async move {
1448 let index_state = if let Some(index_state) = index_state {
1449 Some(index_state.lock_owned().await)
1450 } else {
1451 None
1452 };
1453
1454 let cursor_point = position.to_point(&snapshot);
1455
1456 let debug_info = true;
1457 EditPredictionContext::gather_context(
1458 cursor_point,
1459 &snapshot,
1460 parent_abs_path.as_deref(),
1461 match &options.context {
1462 ContextMode::Agentic(_) => {
1463 // TODO
1464 panic!("Llm mode not supported in zeta cli yet");
1465 }
1466 ContextMode::Syntax(edit_prediction_context_options) => {
1467 edit_prediction_context_options
1468 }
1469 },
1470 index_state.as_deref(),
1471 )
1472 .context("Failed to select excerpt")
1473 .map(|context| {
1474 make_syntax_context_cloud_request(
1475 excerpt_path.into(),
1476 context,
1477 // TODO pass everything
1478 Vec::new(),
1479 false,
1480 Vec::new(),
1481 false,
1482 None,
1483 debug_info,
1484 &worktree_snapshots,
1485 index_state.as_deref(),
1486 Some(options.max_prompt_bytes),
1487 options.prompt_format,
1488 )
1489 })
1490 })
1491 }
1492
1493 pub fn wait_for_initial_indexing(
1494 &mut self,
1495 project: &Entity<Project>,
1496 cx: &mut App,
1497 ) -> Task<Result<()>> {
1498 let zeta_project = self.get_or_init_zeta_project(project, cx);
1499 zeta_project
1500 .syntax_index
1501 .read(cx)
1502 .wait_for_initial_file_indexing(cx)
1503 }
1504}
1505
/// Error raised when the prediction service reports that this Zed build is
/// older than the minimum version it supports (see `send_api_request`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Lowest acceptable Zed version, parsed from the server's response header.
    minimum_version: SemanticVersion,
}
1513
1514fn make_syntax_context_cloud_request(
1515 excerpt_path: Arc<Path>,
1516 context: EditPredictionContext,
1517 events: Vec<predict_edits_v3::Event>,
1518 can_collect_data: bool,
1519 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1520 diagnostic_groups_truncated: bool,
1521 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1522 debug_info: bool,
1523 worktrees: &Vec<worktree::Snapshot>,
1524 index_state: Option<&SyntaxIndexState>,
1525 prompt_max_bytes: Option<usize>,
1526 prompt_format: PromptFormat,
1527) -> predict_edits_v3::PredictEditsRequest {
1528 let mut signatures = Vec::new();
1529 let mut declaration_to_signature_index = HashMap::default();
1530 let mut referenced_declarations = Vec::new();
1531
1532 for snippet in context.declarations {
1533 let project_entry_id = snippet.declaration.project_entry_id();
1534 let Some(path) = worktrees.iter().find_map(|worktree| {
1535 worktree.entry_for_id(project_entry_id).map(|entry| {
1536 let mut full_path = RelPathBuf::new();
1537 full_path.push(worktree.root_name());
1538 full_path.push(&entry.path);
1539 full_path
1540 })
1541 }) else {
1542 continue;
1543 };
1544
1545 let parent_index = index_state.and_then(|index_state| {
1546 snippet.declaration.parent().and_then(|parent| {
1547 add_signature(
1548 parent,
1549 &mut declaration_to_signature_index,
1550 &mut signatures,
1551 index_state,
1552 )
1553 })
1554 });
1555
1556 let (text, text_is_truncated) = snippet.declaration.item_text();
1557 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1558 path: path.as_std_path().into(),
1559 text: text.into(),
1560 range: snippet.declaration.item_line_range(),
1561 text_is_truncated,
1562 signature_range: snippet.declaration.signature_range_in_item_text(),
1563 parent_index,
1564 signature_score: snippet.score(DeclarationStyle::Signature),
1565 declaration_score: snippet.score(DeclarationStyle::Declaration),
1566 score_components: snippet.components,
1567 });
1568 }
1569
1570 let excerpt_parent = index_state.and_then(|index_state| {
1571 context
1572 .excerpt
1573 .parent_declarations
1574 .last()
1575 .and_then(|(parent, _)| {
1576 add_signature(
1577 *parent,
1578 &mut declaration_to_signature_index,
1579 &mut signatures,
1580 index_state,
1581 )
1582 })
1583 });
1584
1585 predict_edits_v3::PredictEditsRequest {
1586 excerpt_path,
1587 excerpt: context.excerpt_text.body,
1588 excerpt_line_range: context.excerpt.line_range,
1589 excerpt_range: context.excerpt.range,
1590 cursor_point: predict_edits_v3::Point {
1591 line: predict_edits_v3::Line(context.cursor_point.row),
1592 column: context.cursor_point.column,
1593 },
1594 referenced_declarations,
1595 included_files: vec![],
1596 signatures,
1597 excerpt_parent,
1598 events,
1599 can_collect_data,
1600 diagnostic_groups,
1601 diagnostic_groups_truncated,
1602 git_info,
1603 debug_info,
1604 prompt_max_bytes,
1605 prompt_format,
1606 }
1607}
1608
1609fn add_signature(
1610 declaration_id: DeclarationId,
1611 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1612 signatures: &mut Vec<Signature>,
1613 index: &SyntaxIndexState,
1614) -> Option<usize> {
1615 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1616 return Some(*signature_index);
1617 }
1618 let Some(parent_declaration) = index.declaration(declaration_id) else {
1619 log::error!("bug: missing parent declaration");
1620 return None;
1621 };
1622 let parent_index = parent_declaration.parent().and_then(|parent| {
1623 add_signature(parent, declaration_to_signature_index, signatures, index)
1624 });
1625 let (text, text_is_truncated) = parent_declaration.signature_text();
1626 let signature_index = signatures.len();
1627 signatures.push(Signature {
1628 text: text.into(),
1629 text_is_truncated,
1630 parent_index,
1631 range: parent_declaration.signature_line_range(),
1632 });
1633 declaration_to_signature_index.insert(declaration_id, signature_index);
1634 Some(signature_index)
1635}
1636
1637#[cfg(test)]
1638mod tests {
1639 use std::{path::Path, sync::Arc};
1640
1641 use client::UserStore;
1642 use clock::FakeSystemClock;
1643 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1644 use futures::{
1645 AsyncReadExt, StreamExt,
1646 channel::{mpsc, oneshot},
1647 };
1648 use gpui::{
1649 Entity, TestAppContext,
1650 http_client::{FakeHttpClient, Response},
1651 prelude::*,
1652 };
1653 use indoc::indoc;
1654 use language::OffsetRangeExt as _;
1655 use open_ai::Usage;
1656 use pretty_assertions::{assert_eq, assert_matches};
1657 use project::{FakeFs, Project};
1658 use serde_json::json;
1659 use settings::SettingsStore;
1660 use util::path;
1661 use uuid::Uuid;
1662
1663 use crate::{BufferEditPrediction, Zeta};
1664
    /// End-to-end check of prediction state transitions: a local prediction,
    /// an agentic context refresh, then a cross-file ("jump") prediction that
    /// becomes local once the target buffer is viewed.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // The fake HTTP client forwards each model request here together with a
        // oneshot sender used to answer it.
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a unified diff editing the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        // Edits targeting the current buffer surface as a Local prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Answer the search-planning request with a single tool call searching
        // the other file.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            regex: ".".to_string(),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // This time the diff edits 2.txt while the cursor is in 1.txt.
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1's perspective this is a Jump prediction into 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // The same prediction is Local when viewed from the target buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1808
    /// Happy-path test: a single prediction request answered with a unified
    /// diff produces the expected single edit at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor placed mid-word on the second line ("How|").
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff rewrites "How" to "How are you?"; only the inserted suffix
        // should remain as an edit, anchored at the cursor position.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1872
    /// Verifies that recent buffer edits are included in the prompt as a diff
    /// of the user's edit history, and that the response is applied normally.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registering the buffer makes zeta track its edits as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the empty second line so an edit event is recorded.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The recorded edit should appear in the prompt as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1946
1947 // Skipped until we start including diagnostics in prompt
1948 // #[gpui::test]
1949 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
1950 // let (zeta, mut req_rx) = init_test(cx);
1951 // let fs = FakeFs::new(cx.executor());
1952 // fs.insert_tree(
1953 // "/root",
1954 // json!({
1955 // "foo.md": "Hello!\nBye"
1956 // }),
1957 // )
1958 // .await;
1959 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1960
1961 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1962 // let diagnostic = lsp::Diagnostic {
1963 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1964 // severity: Some(lsp::DiagnosticSeverity::ERROR),
1965 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1966 // ..Default::default()
1967 // };
1968
1969 // project.update(cx, |project, cx| {
1970 // project.lsp_store().update(cx, |lsp_store, cx| {
1971 // // Create some diagnostics
1972 // lsp_store
1973 // .update_diagnostics(
1974 // LanguageServerId(0),
1975 // lsp::PublishDiagnosticsParams {
1976 // uri: path_to_buffer_uri.clone(),
1977 // diagnostics: vec![diagnostic],
1978 // version: None,
1979 // },
1980 // None,
1981 // language::DiagnosticSourceKind::Pushed,
1982 // &[],
1983 // cx,
1984 // )
1985 // .unwrap();
1986 // });
1987 // });
1988
1989 // let buffer = project
1990 // .update(cx, |project, cx| {
1991 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1992 // project.open_buffer(path, cx)
1993 // })
1994 // .await
1995 // .unwrap();
1996
1997 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1998 // let position = snapshot.anchor_before(language::Point::new(0, 0));
1999
2000 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2001 // zeta.request_prediction(&project, &buffer, position, cx)
2002 // });
2003
2004 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2005
2006 // assert_eq!(request.diagnostic_groups.len(), 1);
2007 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2008 // .unwrap();
2009 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2010 // assert_eq!(
2011 // value,
2012 // json!({
2013 // "entries": [{
2014 // "range": {
2015 // "start": 8,
2016 // "end": 10
2017 // },
2018 // "diagnostic": {
2019 // "source": null,
2020 // "code": null,
2021 // "code_description": null,
2022 // "severity": 1,
2023 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2024 // "markdown": null,
2025 // "group_id": 0,
2026 // "is_primary": true,
2027 // "is_disk_based": false,
2028 // "is_unnecessary": false,
2029 // "source_kind": "Pushed",
2030 // "data": null,
2031 // "underline": true
2032 // }
2033 // }],
2034 // "primary_ix": 0
2035 // })
2036 // );
2037 // }
2038
2039 fn model_response(text: &str) -> open_ai::Response {
2040 open_ai::Response {
2041 id: Uuid::new_v4().to_string(),
2042 object: "response".into(),
2043 created: 0,
2044 model: "model".into(),
2045 choices: vec![open_ai::Choice {
2046 index: 0,
2047 message: open_ai::RequestMessage::Assistant {
2048 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2049 tool_calls: vec![],
2050 },
2051 finish_reason: None,
2052 }],
2053 usage: Usage {
2054 prompt_tokens: 0,
2055 completion_tokens: 0,
2056 total_tokens: 0,
2057 },
2058 }
2059 }
2060
2061 fn prompt_from_request(request: &open_ai::Request) -> &str {
2062 assert_eq!(request.messages.len(), 1);
2063 let open_ai::RequestMessage::User {
2064 content: open_ai::MessageContent::Plain(content),
2065 ..
2066 } = &request.messages[0]
2067 else {
2068 panic!(
2069 "Request does not have single user message of type Plain. {:#?}",
2070 request
2071 );
2072 };
2073 content
2074 }
2075
    /// Sets up a test `Zeta` backed by a fake HTTP client.
    ///
    /// Returns the zeta entity plus a receiver that yields each model request
    /// sent to `/predict_edits/raw` alongside a oneshot sender the test uses to
    /// supply the model's response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token endpoint: always hand out a fixed token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: forward the decoded request
                            // to the test and block until it answers.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            // Fake credentials so authenticated requests succeed.
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2132}