1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
10use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;
11use collections::HashMap;
12use edit_prediction_context::{
13 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
14 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
15 SyntaxIndex, SyntaxIndexState,
16};
17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
18use futures::AsyncReadExt as _;
19use futures::channel::{mpsc, oneshot};
20use gpui::http_client::{AsyncBody, Method};
21use gpui::{
22 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
23 http_client, prelude::*,
24};
25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
26use language::{BufferSnapshot, OffsetRangeExt};
27use language_model::{LlmApiToken, RefreshLlmTokenListener};
28use open_ai::FunctionDefinition;
29use project::Project;
30use release_channel::AppVersion;
31use serde::de::DeserializeOwned;
32use std::collections::{VecDeque, hash_map};
33use uuid::Uuid;
34
35use std::ops::Range;
36use std::path::Path;
37use std::str::FromStr as _;
38use std::sync::Arc;
39use std::time::{Duration, Instant};
40use thiserror::Error;
41use util::rel_path::RelPathBuf;
42use util::{LogErrorFuture, TryFutureExt};
43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
44
45pub mod merge_excerpts;
46mod prediction;
47mod provider;
48pub mod retrieval_search;
49pub mod udiff;
50
51use crate::merge_excerpts::merge_excerpts;
52use crate::prediction::EditPrediction;
53pub use crate::prediction::EditPredictionId;
54pub use provider::ZetaEditPredictionProvider;
55
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Default sizing for the excerpt selected around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for roughly half of the excerpt bytes before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Agentic (LLM-driven search) context retrieval is the default mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context gathering.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used by a freshly constructed [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    // Edits to the same buffer within this interval are coalesced into one event.
    buffer_change_grouping_interval: Duration::from_secs(1),
};
90
/// Feature flag gating the zeta2 edit prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Not auto-enabled for staff; must be turned on explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
100
/// Global handle to the singleton [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
105
/// Entry point to the zeta2 edit prediction engine. A single shared instance
/// (see [`Zeta::global`]) tracks per-project state.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Token used to authenticate with the LLM endpoints.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set when the server reported that this Zed version is too old.
    update_required: bool,
    // When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
116
/// Tunable options for prediction requests and context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    // Upper bound on the size of the generated prompt, in bytes.
    pub max_prompt_bytes: usize,
    // Upper bound on the bytes of diagnostics included in a request.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // Buffer edits closer together than this are coalesced into one event.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction request is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    // Context is retrieved via LLM-generated searches over the project.
    Agentic(AgenticContextOptions),
    // Context is gathered from the syntax index.
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
137
138impl ContextMode {
139 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
140 match self {
141 ContextMode::Agentic(options) => &options.excerpt,
142 ContextMode::Syntax(options) => &options.excerpt,
143 }
144 }
145}
146
/// Debug events emitted over the channel returned by [`Zeta::debug_info`].
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // Prompt sent to the model to generate search queries.
    pub search_prompt: String,
}

#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    // Time spent gathering context before the request was sent.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // Locally-built prompt, or the error message from building it.
    pub local_prompt: Result<String, String>,
    // Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    // Generated search regexes, keyed by the glob they apply to.
    pub regex_by_glob: HashMap<String, String>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
187
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    // Recent buffer-change events, oldest first (bounded by MAX_EVENT_COUNT).
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Context excerpt ranges, per buffer. `None` until a retrieval completes.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    // When a context refresh was last requested; used for idle detection.
    refresh_context_timestamp: Option<Instant>,
}

/// The prediction currently offered to the user, along with the id of the
/// buffer it was requested from (which may differ from the buffer it edits).
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
204
205impl CurrentEditPrediction {
206 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
207 let Some(new_edits) = self
208 .prediction
209 .interpolate(&self.prediction.buffer.read(cx))
210 else {
211 return false;
212 };
213
214 if self.prediction.buffer != old_prediction.prediction.buffer {
215 return true;
216 }
217
218 let Some(old_edits) = old_prediction
219 .prediction
220 .interpolate(&old_prediction.prediction.buffer.read(cx))
221 else {
222 return true;
223 };
224
225 // This reduces the occurrence of UI thrash from replacing edits
226 //
227 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
228 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
229 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
230 && old_edits.len() == 1
231 && new_edits.len() == 1
232 {
233 let (old_range, old_text) = &old_edits[0];
234 let (new_range, new_text) = &new_edits[0];
235 new_range == old_range && new_text.starts_with(old_text.as_ref())
236 } else {
237 true
238 }
239 }
240}
241
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    // The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    // The prediction edits a different buffer; offer a jump to it.
    Jump { prediction: &'a EditPrediction },
}

/// Subscription state for a buffer Zeta is tracking.
struct RegisteredBuffer {
    // Last snapshot seen; diffed against the next change to produce events.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
262
263impl Event {
264 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
265 match self {
266 Event::BufferChange {
267 old_snapshot,
268 new_snapshot,
269 ..
270 } => {
271 let path = new_snapshot.file().map(|f| f.full_path(cx));
272
273 let old_path = old_snapshot.file().and_then(|f| {
274 let old_path = f.full_path(cx);
275 if Some(&old_path) != path.as_ref() {
276 Some(old_path)
277 } else {
278 None
279 }
280 });
281
282 // TODO [zeta2] move to bg?
283 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
284
285 if path == old_path && diff.is_empty() {
286 None
287 } else {
288 Some(predict_edits_v3::Event::BufferChange {
289 old_path,
290 path,
291 diff,
292 //todo: Actually detect if this edit was predicted or not
293 predicted: false,
294 })
295 }
296 }
297 }
298 }
299}
300
301impl Zeta {
302 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
303 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
304 }
305
306 pub fn global(
307 client: &Arc<Client>,
308 user_store: &Entity<UserStore>,
309 cx: &mut App,
310 ) -> Entity<Self> {
311 cx.try_global::<ZetaGlobal>()
312 .map(|global| global.0.clone())
313 .unwrap_or_else(|| {
314 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
315 cx.set_global(ZetaGlobal(zeta.clone()));
316 zeta
317 })
318 }
319
    /// Creates a new [`Zeta`]. Prefer [`Zeta::global`] to obtain the shared
    /// instance.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever we're told it needs refreshing.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
        }
    }
345
346 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
347 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
348 self.debug_tx = Some(debug_watch_tx);
349 debug_watch_rx
350 }
351
    /// Current prediction/context options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
355
    /// Replaces the prediction/context options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
359
360 pub fn clear_history(&mut self) {
361 for zeta_project in self.projects.values_mut() {
362 zeta_project.events.clear();
363 }
364 }
365
366 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
367 self.projects
368 .get(&project.entity_id())
369 .map(|project| project.events.iter())
370 .into_iter()
371 .flatten()
372 }
373
374 pub fn context_for_project(
375 &self,
376 project: &Entity<Project>,
377 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
378 self.projects
379 .get(&project.entity_id())
380 .and_then(|project| {
381 Some(
382 project
383 .context
384 .as_ref()?
385 .iter()
386 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
387 )
388 })
389 .into_iter()
390 .flatten()
391 }
392
    /// Returns the user's edit prediction usage, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
396
    /// Ensures per-project state (including the syntax index) exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
400
    /// Starts tracking `buffer` within `project`, subscribing to its edits so
    /// they are recorded in the project's edit history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
410
411 fn get_or_init_zeta_project(
412 &mut self,
413 project: &Entity<Project>,
414 cx: &mut App,
415 ) -> &mut ZetaProject {
416 self.projects
417 .entry(project.entity_id())
418 .or_insert_with(|| ZetaProject {
419 syntax_index: cx.new(|cx| {
420 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
421 }),
422 events: VecDeque::new(),
423 registered_buffers: HashMap::default(),
424 current_prediction: None,
425 context: None,
426 refresh_context_task: None,
427 refresh_context_debounce_task: None,
428 refresh_context_timestamp: None,
429 })
430 }
431
    /// Returns the [`RegisteredBuffer`] for `buffer`, registering it on first
    /// use: subscribes to buffer edits (to record history) and to entity
    /// release (to drop the registration).
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                // Record an edit-history event for each edit,
                                // as long as the project is still alive.
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
469
    /// Records an edit-history event for `buffer` if its contents changed
    /// since the last snapshot we saw, and returns the latest snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Only push an event when the buffer's version actually advanced.
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                buffer_change_grouping_interval,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }
497
    /// Appends `event` to the project's history, coalescing consecutive edits
    /// to the same buffer that occur within `buffer_change_grouping_interval`,
    /// and bounding the history length to `MAX_EVENT_COUNT`.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            // The new event continues the previous one when it starts from
            // that event's resulting version and falls within the interval.
            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
536
537 fn current_prediction_for_buffer(
538 &self,
539 buffer: &Entity<Buffer>,
540 project: &Entity<Project>,
541 cx: &App,
542 ) -> Option<BufferEditPrediction<'_>> {
543 let project_state = self.projects.get(&project.entity_id())?;
544
545 let CurrentEditPrediction {
546 requested_by_buffer_id,
547 prediction,
548 } = project_state.current_prediction.as_ref()?;
549
550 if prediction.targets_buffer(buffer.read(cx)) {
551 Some(BufferEditPrediction::Local { prediction })
552 } else if *requested_by_buffer_id == buffer.entity_id() {
553 Some(BufferEditPrediction::Jump { prediction })
554 } else {
555 None
556 }
557 }
558
    /// Takes the current prediction for `project` and reports its acceptance
    /// to the server (fire-and-forget).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // Allow overriding the endpoint for local development.
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Records usage from headers and surfaces update-required errors.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
600
601 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
602 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
603 project_state.current_prediction.take();
604 };
605 }
606
607 pub fn refresh_prediction(
608 &mut self,
609 project: &Entity<Project>,
610 buffer: &Entity<Buffer>,
611 position: language::Anchor,
612 cx: &mut Context<Self>,
613 ) -> Task<Result<()>> {
614 let request_task = self.request_prediction(project, buffer, position, cx);
615 let buffer = buffer.clone();
616 let project = project.clone();
617
618 cx.spawn(async move |this, cx| {
619 if let Some(prediction) = request_task.await? {
620 this.update(cx, |this, cx| {
621 let project_state = this
622 .projects
623 .get_mut(&project.entity_id())
624 .context("Project not found")?;
625
626 let new_prediction = CurrentEditPrediction {
627 requested_by_buffer_id: buffer.entity_id(),
628 prediction: prediction,
629 };
630
631 if project_state
632 .current_prediction
633 .as_ref()
634 .is_none_or(|old_prediction| {
635 new_prediction.should_replace_prediction(&old_prediction, cx)
636 })
637 {
638 project_state.current_prediction = Some(new_prediction);
639 }
640 anyhow::Ok(())
641 })??;
642 }
643 Ok(())
644 })
645 }
646
    /// Requests an edit prediction for `position` in `active_buffer`.
    ///
    /// Builds a request from the cursor excerpt, previously retrieved context
    /// excerpts, recent edit events, and nearby diagnostics; sends it to the
    /// model; then parses the returned unified diff into concrete buffer
    /// edits. Resolves to `None` when no excerpt/context could be gathered or
    /// the model produced no usable output.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Serialize the recent edit history into request events.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the buffer's parent directory, when available.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Previously retrieved context, as (buffer, snapshot, path, ranges).
        // Buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        // Select the excerpt around the cursor.
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the included files,
                        // moving the active buffer's entry to the end.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            // Insert while keeping the ranges sorted.
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges into merged line-based excerpts.
                        let included_files = included_files
                            .iter()
                            .map(|(_, buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging, notify listeners about the request and hand
                // them a oneshot channel for the eventual response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Escape hatch for local debugging: never hit the server.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response =
                    Self::send_raw_llm_request(client, llm_token, app_version, request).await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(Uuid::from_str(&res.id)?);
                let Some(output_text) = text_from_response(res) else {
                    return Ok((None, usage))
                };

                // Parse the model's unified diff against the included files.
                let (edited_buffer_snapshot, edits) =
                    crate::udiff::parse_diff(&output_text, |path| {
                        included_files
                            .iter()
                            .find_map(|(_, buffer, probe_path, ranges)| {
                                if probe_path.as_ref() == path {
                                    Some((buffer, ranges.as_slice()))
                                } else {
                                    None
                                }
                            })
                    })
                    .await?;

                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;

                anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
973
974 async fn send_raw_llm_request(
975 client: Arc<Client>,
976 llm_token: LlmApiToken,
977 app_version: SemanticVersion,
978 request: open_ai::Request,
979 ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
980 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
981 http_client::Url::parse(&predict_edits_url)?
982 } else {
983 client
984 .http_client()
985 .build_zed_llm_url("/predict_edits/raw", &[])?
986 };
987
988 Self::send_api_request(
989 |builder| {
990 let req = builder
991 .uri(url.as_ref())
992 .body(serde_json::to_string(&request)?.into());
993 Ok(req?)
994 },
995 client,
996 llm_token,
997 app_version,
998 )
999 .await
1000 }
1001
    /// Common handling for API responses: records usage reported alongside
    /// the payload, and on [`ZedUpdateRequiredError`] flags that an update is
    /// required and shows a notification with an update link.
    fn handle_api_response<T>(
        this: &WeakEntity<Self>,
        response: Result<(T, Option<EditPredictionUsage>)>,
        cx: &mut gpui::AsyncApp,
    ) -> Result<T> {
        match response {
            Ok((data, usage)) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }
                Ok(data)
            }
            Err(err) => {
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }
                Err(err)
            }
        }
    }
1045
    /// Sends an authenticated POST request built by `build` and deserializes
    /// the JSON response.
    ///
    /// Refreshes the LLM token and retries once when the server reports an
    /// expired token. Fails with [`ZedUpdateRequiredError`] when the server
    /// requires a newer Zed version.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // The server may demand a minimum client version via a header.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh and retry exactly once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1109
    /// After this long without edits, the next edit triggers an immediate
    /// context refresh.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied to context refreshes while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1112
    // Refresh the related excerpts when the user has just begun editing after
    // an idle period, and after they pause editing.
    //
    // Only applies in `ContextMode::Agentic`; no-op for unregistered projects.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // "Idle" means no refresh was recorded recently (or ever).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the debounce task cancels any pending (not-yet-fired) refresh.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Resuming after idle: refresh right away.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Mid-editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the handle so the refresh runs to completion and
                    // `refresh_context_task.take()` can observe it finishing.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1160
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Pipeline: build a retrieval prompt around the cursor excerpt → ask the
    // LLM (via a tool call) for search queries → run those regex searches in
    // the project → store the resulting excerpts as the project's context.
    // Returns a ready task (Ok) when the project is unregistered, the context
    // mode isn't agentic, or no excerpt can be selected.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path for the prompt.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the search-planning prompt from the cursor excerpt plus the
        // project's recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        // Debug notifications are best-effort; failures to send are ignored.
        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;

        // Single user message; the model is expected to respond with calls to
        // the retrieval search tool.
        let request = open_ai::Request {
            model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description.clone()),
                    parameters: Some(tool_schema.clone()),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response =
                Self::send_raw_llm_request(client, llm_token, app_version, request).await;
            let mut response = Self::handle_api_response(&this, response, cx)?;

            log::trace!("Got search planning response");

            // Only the last choice is used.
            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collapse all requested queries into one alternation regex per
            // glob, so each glob is searched once.
            let mut regex_by_glob: HashMap<String, String> = HashMap::default();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
                for query in input.queries {
                    let regex = regex_by_glob.entry(query.glob).or_default();
                    if !regex.is_empty() {
                        regex.push('|');
                    }
                    regex.push_str(&query.regex);
                }
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            regex_by_glob: regex_by_glob.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {regex_by_glob:#?}");

            let related_excerpts_result =
                retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the handle so a subsequent refresh may be spawned.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                // Only overwrite stored context on success.
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1350
1351 fn gather_nearby_diagnostics(
1352 cursor_offset: usize,
1353 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1354 snapshot: &BufferSnapshot,
1355 max_diagnostics_bytes: usize,
1356 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1357 // TODO: Could make this more efficient
1358 let mut diagnostic_groups = Vec::new();
1359 for (language_server_id, diagnostics) in diagnostic_sets {
1360 let mut groups = Vec::new();
1361 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1362 diagnostic_groups.extend(
1363 groups
1364 .into_iter()
1365 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1366 );
1367 }
1368
1369 // sort by proximity to cursor
1370 diagnostic_groups.sort_by_key(|group| {
1371 let range = &group.entries[group.primary_ix].range;
1372 if range.start >= cursor_offset {
1373 range.start - cursor_offset
1374 } else if cursor_offset >= range.end {
1375 cursor_offset - range.end
1376 } else {
1377 (cursor_offset - range.start).min(range.end - cursor_offset)
1378 }
1379 });
1380
1381 let mut results = Vec::new();
1382 let mut diagnostic_groups_truncated = false;
1383 let mut diagnostics_byte_count = 0;
1384 for group in diagnostic_groups {
1385 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1386 diagnostics_byte_count += raw_value.get().len();
1387 if diagnostics_byte_count > max_diagnostics_bytes {
1388 diagnostic_groups_truncated = true;
1389 break;
1390 }
1391 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1392 }
1393
1394 (results, diagnostic_groups_truncated)
1395 }
1396
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a cloud `PredictEditsRequest` for the zeta CLI by gathering
    /// syntax-based context around `position` on a background thread.
    ///
    /// Errors if the buffer has no file path or excerpt selection fails.
    /// Panics (TODO) when the configured context mode is agentic, which the
    /// CLI does not support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone of the syntax index state, if this project is registered.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the buffer is
        // backed by a worktree file with a parent.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1473
1474 pub fn wait_for_initial_indexing(
1475 &mut self,
1476 project: &Entity<Project>,
1477 cx: &mut App,
1478 ) -> Task<Result<()>> {
1479 let zeta_project = self.get_or_init_zeta_project(project, cx);
1480 zeta_project
1481 .syntax_index
1482 .read(cx)
1483 .wait_for_initial_file_indexing(cx)
1484 }
1485}
1486
1487pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1488 let choice = res.choices.pop()?;
1489 let output_text = match choice.message {
1490 open_ai::RequestMessage::Assistant {
1491 content: Some(open_ai::MessageContent::Plain(content)),
1492 ..
1493 } => content,
1494 open_ai::RequestMessage::Assistant {
1495 content: Some(open_ai::MessageContent::Multipart(mut content)),
1496 ..
1497 } => {
1498 if content.is_empty() {
1499 log::error!("No output from Baseten completion response");
1500 return None;
1501 }
1502
1503 match content.remove(0) {
1504 open_ai::MessagePart::Text { text } => text,
1505 open_ai::MessagePart::Image { .. } => {
1506 log::error!("Expected text, got an image");
1507 return None;
1508 }
1509 }
1510 }
1511 _ => {
1512 log::error!("Invalid response message: {:?}", choice.message);
1513 return None;
1514 }
1515 };
1516 Some(output_text)
1517}
1518
/// Error raised when the prediction service reports (via response headers)
/// that this Zed build is older than the minimum it supports.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Lowest Zed version the server will accept.
    minimum_version: SemanticVersion,
}
1526
1527fn make_syntax_context_cloud_request(
1528 excerpt_path: Arc<Path>,
1529 context: EditPredictionContext,
1530 events: Vec<predict_edits_v3::Event>,
1531 can_collect_data: bool,
1532 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1533 diagnostic_groups_truncated: bool,
1534 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1535 debug_info: bool,
1536 worktrees: &Vec<worktree::Snapshot>,
1537 index_state: Option<&SyntaxIndexState>,
1538 prompt_max_bytes: Option<usize>,
1539 prompt_format: PromptFormat,
1540) -> predict_edits_v3::PredictEditsRequest {
1541 let mut signatures = Vec::new();
1542 let mut declaration_to_signature_index = HashMap::default();
1543 let mut referenced_declarations = Vec::new();
1544
1545 for snippet in context.declarations {
1546 let project_entry_id = snippet.declaration.project_entry_id();
1547 let Some(path) = worktrees.iter().find_map(|worktree| {
1548 worktree.entry_for_id(project_entry_id).map(|entry| {
1549 let mut full_path = RelPathBuf::new();
1550 full_path.push(worktree.root_name());
1551 full_path.push(&entry.path);
1552 full_path
1553 })
1554 }) else {
1555 continue;
1556 };
1557
1558 let parent_index = index_state.and_then(|index_state| {
1559 snippet.declaration.parent().and_then(|parent| {
1560 add_signature(
1561 parent,
1562 &mut declaration_to_signature_index,
1563 &mut signatures,
1564 index_state,
1565 )
1566 })
1567 });
1568
1569 let (text, text_is_truncated) = snippet.declaration.item_text();
1570 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1571 path: path.as_std_path().into(),
1572 text: text.into(),
1573 range: snippet.declaration.item_line_range(),
1574 text_is_truncated,
1575 signature_range: snippet.declaration.signature_range_in_item_text(),
1576 parent_index,
1577 signature_score: snippet.score(DeclarationStyle::Signature),
1578 declaration_score: snippet.score(DeclarationStyle::Declaration),
1579 score_components: snippet.components,
1580 });
1581 }
1582
1583 let excerpt_parent = index_state.and_then(|index_state| {
1584 context
1585 .excerpt
1586 .parent_declarations
1587 .last()
1588 .and_then(|(parent, _)| {
1589 add_signature(
1590 *parent,
1591 &mut declaration_to_signature_index,
1592 &mut signatures,
1593 index_state,
1594 )
1595 })
1596 });
1597
1598 predict_edits_v3::PredictEditsRequest {
1599 excerpt_path,
1600 excerpt: context.excerpt_text.body,
1601 excerpt_line_range: context.excerpt.line_range,
1602 excerpt_range: context.excerpt.range,
1603 cursor_point: predict_edits_v3::Point {
1604 line: predict_edits_v3::Line(context.cursor_point.row),
1605 column: context.cursor_point.column,
1606 },
1607 referenced_declarations,
1608 included_files: vec![],
1609 signatures,
1610 excerpt_parent,
1611 events,
1612 can_collect_data,
1613 diagnostic_groups,
1614 diagnostic_groups_truncated,
1615 git_info,
1616 debug_info,
1617 prompt_max_bytes,
1618 prompt_format,
1619 }
1620}
1621
1622fn add_signature(
1623 declaration_id: DeclarationId,
1624 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1625 signatures: &mut Vec<Signature>,
1626 index: &SyntaxIndexState,
1627) -> Option<usize> {
1628 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1629 return Some(*signature_index);
1630 }
1631 let Some(parent_declaration) = index.declaration(declaration_id) else {
1632 log::error!("bug: missing parent declaration");
1633 return None;
1634 };
1635 let parent_index = parent_declaration.parent().and_then(|parent| {
1636 add_signature(parent, declaration_to_signature_index, signatures, index)
1637 });
1638 let (text, text_is_truncated) = parent_declaration.signature_text();
1639 let signature_index = signatures.len();
1640 signatures.push(Signature {
1641 text: text.into(),
1642 text_is_truncated,
1643 parent_index,
1644 range: parent_declaration.signature_line_range(),
1645 });
1646 declaration_to_signature_index.insert(declaration_id, signature_index);
1647 Some(signature_index)
1648}
1649
1650#[cfg(test)]
1651mod tests {
1652 use std::{path::Path, sync::Arc};
1653
1654 use client::UserStore;
1655 use clock::FakeSystemClock;
1656 use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1657 use futures::{
1658 AsyncReadExt, StreamExt,
1659 channel::{mpsc, oneshot},
1660 };
1661 use gpui::{
1662 Entity, TestAppContext,
1663 http_client::{FakeHttpClient, Response},
1664 prelude::*,
1665 };
1666 use indoc::indoc;
1667 use language::OffsetRangeExt as _;
1668 use open_ai::Usage;
1669 use pretty_assertions::{assert_eq, assert_matches};
1670 use project::{FakeFs, Project};
1671 use serde_json::json;
1672 use settings::SettingsStore;
1673 use util::path;
1674 use uuid::Uuid;
1675
1676 use crate::{BufferEditPrediction, Zeta};
1677
    /// End-to-end check of prediction state transitions: a local prediction
    /// for the current buffer, a context refresh driven by a fake tool call,
    /// then a "jump" prediction targeting a different file which becomes
    /// local once that file's buffer is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff editing the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        // An edit to the current buffer surfaces as a Local prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Fake a retrieval-planning response: one search tool call covering
        // all of 2.txt, which pulls it into the project context.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            regex: ".".to_string(),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // Viewed from buffer1, an edit targeting 2.txt is a Jump prediction.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // The same prediction is Local when viewed from the target buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1821
    /// Happy-path request: a single prediction round-trip where the model's
    /// diff response is applied as one edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of "How" (row 1, column 3).
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff should reduce to a single insertion right at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1885
    /// Verifies that edits made to a registered buffer are reported to the
    /// model as a diff of recent events inside the request prompt.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registration is what makes Zeta track this buffer's edit events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Insert "How" on the empty second line (byte offset 7).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt must contain the recorded edit as a unified diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ -1,3 +1,3 @@
                 Hello!
                -
                +How
                 Bye
            "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
1959
1960 // Skipped until we start including diagnostics in prompt
1961 // #[gpui::test]
1962 // async fn test_request_diagnostics(cx: &mut TestAppContext) {
1963 // let (zeta, mut req_rx) = init_test(cx);
1964 // let fs = FakeFs::new(cx.executor());
1965 // fs.insert_tree(
1966 // "/root",
1967 // json!({
1968 // "foo.md": "Hello!\nBye"
1969 // }),
1970 // )
1971 // .await;
1972 // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1973
1974 // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1975 // let diagnostic = lsp::Diagnostic {
1976 // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1977 // severity: Some(lsp::DiagnosticSeverity::ERROR),
1978 // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1979 // ..Default::default()
1980 // };
1981
1982 // project.update(cx, |project, cx| {
1983 // project.lsp_store().update(cx, |lsp_store, cx| {
1984 // // Create some diagnostics
1985 // lsp_store
1986 // .update_diagnostics(
1987 // LanguageServerId(0),
1988 // lsp::PublishDiagnosticsParams {
1989 // uri: path_to_buffer_uri.clone(),
1990 // diagnostics: vec![diagnostic],
1991 // version: None,
1992 // },
1993 // None,
1994 // language::DiagnosticSourceKind::Pushed,
1995 // &[],
1996 // cx,
1997 // )
1998 // .unwrap();
1999 // });
2000 // });
2001
2002 // let buffer = project
2003 // .update(cx, |project, cx| {
2004 // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2005 // project.open_buffer(path, cx)
2006 // })
2007 // .await
2008 // .unwrap();
2009
2010 // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2011 // let position = snapshot.anchor_before(language::Point::new(0, 0));
2012
2013 // let _prediction_task = zeta.update(cx, |zeta, cx| {
2014 // zeta.request_prediction(&project, &buffer, position, cx)
2015 // });
2016
2017 // let (request, _respond_tx) = req_rx.next().await.unwrap();
2018
2019 // assert_eq!(request.diagnostic_groups.len(), 1);
2020 // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2021 // .unwrap();
2022 // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2023 // assert_eq!(
2024 // value,
2025 // json!({
2026 // "entries": [{
2027 // "range": {
2028 // "start": 8,
2029 // "end": 10
2030 // },
2031 // "diagnostic": {
2032 // "source": null,
2033 // "code": null,
2034 // "code_description": null,
2035 // "severity": 1,
2036 // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2037 // "markdown": null,
2038 // "group_id": 0,
2039 // "is_primary": true,
2040 // "is_disk_based": false,
2041 // "is_unnecessary": false,
2042 // "source_kind": "Pushed",
2043 // "data": null,
2044 // "underline": true
2045 // }
2046 // }],
2047 // "primary_ix": 0
2048 // })
2049 // );
2050 // }
2051
2052 fn model_response(text: &str) -> open_ai::Response {
2053 open_ai::Response {
2054 id: Uuid::new_v4().to_string(),
2055 object: "response".into(),
2056 created: 0,
2057 model: "model".into(),
2058 choices: vec![open_ai::Choice {
2059 index: 0,
2060 message: open_ai::RequestMessage::Assistant {
2061 content: Some(open_ai::MessageContent::Plain(text.to_string())),
2062 tool_calls: vec![],
2063 },
2064 finish_reason: None,
2065 }],
2066 usage: Usage {
2067 prompt_tokens: 0,
2068 completion_tokens: 0,
2069 total_tokens: 0,
2070 },
2071 }
2072 }
2073
2074 fn prompt_from_request(request: &open_ai::Request) -> &str {
2075 assert_eq!(request.messages.len(), 1);
2076 let open_ai::RequestMessage::User {
2077 content: open_ai::MessageContent::Plain(content),
2078 ..
2079 } = &request.messages[0]
2080 else {
2081 panic!(
2082 "Request does not have single user message of type Plain. {:#?}",
2083 request
2084 );
2085 };
2086 content
2087 }
2088
    /// Sets up a Zeta instance backed by a fake HTTP client.
    ///
    /// Returns the Zeta entity plus a channel that yields each intercepted
    /// `/predict_edits/raw` request together with a oneshot sender the test
    /// uses to supply the fake model response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh endpoint: always hand back a
                            // static test token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: forward the parsed request
                            // to the test and wait for its scripted response.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2143}