1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::fmt::Write;
32use std::ops::Range;
33use std::path::Path;
34use std::str::FromStr as _;
35use std::sync::Arc;
36use std::time::{Duration, Instant};
37use thiserror::Error;
38use util::ResultExt as _;
39use util::rel_path::RelPathBuf;
40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
41
42pub mod merge_excerpts;
43mod prediction;
44mod provider;
45pub mod related_excerpts;
46
47use crate::merge_excerpts::merge_excerpts;
48use crate::prediction::EditPrediction;
49use crate::related_excerpts::find_related_excerpts;
50pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
51pub use provider::ZetaEditPredictionProvider;
52
/// Maximum number of edit events kept in a project's history.
///
/// When the history is full, `Zeta::push_event` drops the oldest half at once
/// instead of popping a single event.
const MAX_EVENT_COUNT: usize = 16;
55
/// Default options for selecting the excerpt of the buffer around the cursor.
///
/// Shared by both the LLM and syntax context defaults below. Field semantics are
/// defined by `EditPredictionExcerptOptions`; judging by the names, the excerpt
/// is sized between `min_bytes` and `max_bytes` with roughly half of it placed
/// before the cursor (TODO confirm against `EditPredictionExcerpt`).
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
61
/// Default context-gathering mode: LLM-driven retrieval of related excerpts.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
63
/// Default options for `ContextMode::Llm`.
pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
67
/// Default options for `ContextMode::Syntax`.
///
/// NOTE(review): `max_retrieved_declarations` is 0 here, which presumably means
/// no declarations are retrieved by default — confirm this is intentional.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
77
/// Options a new `Zeta` starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    // Edits to the same buffer within this window are coalesced into one event.
    buffer_change_grouping_interval: Duration::from_secs(1),
};
86
/// Feature flag named `"zeta2"` gating these edit predictions.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Not automatically enabled for staff; presumably it must be turned on
    // explicitly (TODO confirm against `FeatureFlag` defaults).
    fn enabled_for_staff() -> bool {
        false
    }
}
96
/// App-global slot holding the shared `Zeta` entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
101
/// Edit prediction engine.
///
/// Tracks per-project state (edit history, registered buffers, retrieved
/// context and the current prediction) and sends requests to the Zed LLM
/// service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// LLM API token, refreshed whenever the global refresh listener emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when the server reports that this Zed version is too old.
    update_required: bool,
    /// When set, debug events are sent through this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
112
/// Tunable options controlling how prediction requests are built.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered.
    pub context: ContextMode,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized nearby diagnostics; extra groups are truncated.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first seen.
    pub file_indexing_parallelism: usize,
    /// Consecutive edits to one buffer within this interval are coalesced into a
    /// single history event; `Duration::ZERO` disables coalescing.
    pub buffer_change_grouping_interval: Duration,
}
122
123#[derive(Debug, Clone, PartialEq)]
124pub enum ContextMode {
125 Llm(LlmContextOptions),
126 Syntax(EditPredictionContextOptions),
127}
128
129impl ContextMode {
130 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
131 match self {
132 ContextMode::Llm(options) => &options.excerpt,
133 ContextMode::Syntax(options) => &options.excerpt,
134 }
135 }
136}
137
/// Debug events emitted through the channel returned by `Zeta::debug_info`.
pub enum ZetaDebugInfo {
    // The retrieval-stage variants below are presumably emitted by
    // `related_excerpts` (not visible here) — TODO confirm.
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
    /// Sent by `Zeta::refresh_context` after storing the retrieved excerpts.
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// Sent by `Zeta::request_prediction` before the request goes out.
    EditPredicted(ZetaEditPredictionDebugInfo),
}
146
/// Payload for `ZetaDebugInfo::ContextRetrievalStarted`.
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload for context-retrieval progress events, identifying the project and when it happened.
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for `ZetaDebugInfo::EditPredicted`.
pub struct ZetaEditPredictionDebugInfo {
    /// The request as it will be sent to the server.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent between starting context gathering and finishing the request.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The prompt built locally with `build_prompt`, or its error message.
    pub local_prompt: Result<String, String>,
    /// Receives the server response, or an error message if the request failed
    /// or was skipped.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub queries: Vec<SearchToolQuery>,
}
172
/// Re-export of the protocol's request debug information type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
174
/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    /// Syntax index for the project; its state is used for excerpt selection and syntax context.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts stored by the last completed context refresh.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context refresh; cleared when it completes.
    refresh_context_task: Option<Task<Option<()>>>,
    /// Refresh scheduled by `refresh_context_if_needed` (possibly debounced).
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When `refresh_context_if_needed` last ran; used to detect idle periods.
    refresh_context_timestamp: Option<Instant>,
}
185
186#[derive(Debug, Clone)]
187struct CurrentEditPrediction {
188 pub requested_by_buffer_id: EntityId,
189 pub prediction: EditPrediction,
190}
191
192impl CurrentEditPrediction {
193 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
194 let Some(new_edits) = self
195 .prediction
196 .interpolate(&self.prediction.buffer.read(cx))
197 else {
198 return false;
199 };
200
201 if self.prediction.buffer != old_prediction.prediction.buffer {
202 return true;
203 }
204
205 let Some(old_edits) = old_prediction
206 .prediction
207 .interpolate(&old_prediction.prediction.buffer.read(cx))
208 else {
209 return true;
210 };
211
212 // This reduces the occurrence of UI thrash from replacing edits
213 //
214 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
215 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
216 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
217 && old_edits.len() == 1
218 && new_edits.len() == 1
219 {
220 let (old_range, old_text) = &old_edits[0];
221 let (new_range, new_text) = &new_edits[0];
222 new_range == old_range && new_text.starts_with(old_text)
223 } else {
224 true
225 }
226 }
227}
228
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}
235
/// A buffer whose edits are tracked for a project.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the baseline for the next diff.
    snapshot: BufferSnapshot,
    /// Subscriptions for buffer edits and buffer release.
    _subscriptions: [gpui::Subscription; 2],
}
240
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Coalesced changes
    /// keep the oldest `old_snapshot` and update `new_snapshot` and `timestamp`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
249
250impl Event {
251 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
252 match self {
253 Event::BufferChange {
254 old_snapshot,
255 new_snapshot,
256 ..
257 } => {
258 let path = new_snapshot.file().map(|f| f.full_path(cx));
259
260 let old_path = old_snapshot.file().and_then(|f| {
261 let old_path = f.full_path(cx);
262 if Some(&old_path) != path.as_ref() {
263 Some(old_path)
264 } else {
265 None
266 }
267 });
268
269 // TODO [zeta2] move to bg?
270 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
271
272 if path == old_path && diff.is_empty() {
273 None
274 } else {
275 Some(predict_edits_v3::Event::BufferChange {
276 old_path,
277 path,
278 diff,
279 //todo: Actually detect if this edit was predicted or not
280 predicted: false,
281 })
282 }
283 }
284 }
285 }
286}
287
288impl Zeta {
289 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
290 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
291 }
292
293 pub fn global(
294 client: &Arc<Client>,
295 user_store: &Entity<UserStore>,
296 cx: &mut App,
297 ) -> Entity<Self> {
298 cx.try_global::<ZetaGlobal>()
299 .map(|global| global.0.clone())
300 .unwrap_or_else(|| {
301 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
302 cx.set_global(ZetaGlobal(zeta.clone()));
303 zeta
304 })
305 }
306
307 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
308 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
309
310 Self {
311 projects: HashMap::default(),
312 client,
313 user_store,
314 options: DEFAULT_OPTIONS,
315 llm_token: LlmApiToken::default(),
316 _llm_token_subscription: cx.subscribe(
317 &refresh_llm_token_listener,
318 |this, _listener, _event, cx| {
319 let client = this.client.clone();
320 let llm_token = this.llm_token.clone();
321 cx.spawn(async move |_this, _cx| {
322 llm_token.refresh(&client).await?;
323 anyhow::Ok(())
324 })
325 .detach_and_log_err(cx);
326 },
327 ),
328 update_required: false,
329 debug_tx: None,
330 }
331 }
332
333 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
334 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
335 self.debug_tx = Some(debug_watch_tx);
336 debug_watch_rx
337 }
338
    /// Current options used when building requests.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the options; affects requests and projects initialized afterwards.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
346
347 pub fn clear_history(&mut self) {
348 for zeta_project in self.projects.values_mut() {
349 zeta_project.events.clear();
350 }
351 }
352
353 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
354 self.projects
355 .get(&project.entity_id())
356 .map(|project| project.events.iter())
357 .into_iter()
358 .flatten()
359 }
360
361 pub fn context_for_project(
362 &self,
363 project: &Entity<Project>,
364 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
365 self.projects
366 .get(&project.entity_id())
367 .and_then(|project| {
368 Some(
369 project
370 .context
371 .as_ref()?
372 .iter()
373 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
374 )
375 })
376 .into_iter()
377 .flatten()
378 }
379
    /// Edit prediction usage currently held by the user store (updated from
    /// response headers in `handle_api_response`).
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
383
    /// Ensures per-project state, including a syntax index, exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
387
    /// Starts tracking edits to `buffer` in `project`'s history.
    ///
    /// Initializes the project state if needed; registering a buffer twice is a no-op.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
397
398 fn get_or_init_zeta_project(
399 &mut self,
400 project: &Entity<Project>,
401 cx: &mut App,
402 ) -> &mut ZetaProject {
403 self.projects
404 .entry(project.entity_id())
405 .or_insert_with(|| ZetaProject {
406 syntax_index: cx.new(|cx| {
407 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
408 }),
409 events: VecDeque::new(),
410 registered_buffers: HashMap::default(),
411 current_prediction: None,
412 context: None,
413 refresh_context_task: None,
414 refresh_context_debounce_task: None,
415 refresh_context_timestamp: None,
416 })
417 }
418
    /// Returns `buffer`'s registration in `zeta_project`, creating it on first use.
    ///
    /// A new registration records the buffer's current snapshot and installs
    /// two subscriptions:
    /// - On `BufferEvent::Edited`, changes are reported via
    ///   `report_changes_for_buffer`, but only while the project is still alive.
    /// - When the buffer is released, the registration is removed.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle so the subscription does not keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
456
457 fn report_changes_for_buffer(
458 &mut self,
459 buffer: &Entity<Buffer>,
460 project: &Entity<Project>,
461 cx: &mut Context<Self>,
462 ) -> BufferSnapshot {
463 let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
464 let zeta_project = self.get_or_init_zeta_project(project, cx);
465 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
466
467 let new_snapshot = buffer.read(cx).snapshot();
468 if new_snapshot.version != registered_buffer.snapshot.version {
469 let old_snapshot =
470 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
471 Self::push_event(
472 zeta_project,
473 buffer_change_grouping_interval,
474 Event::BufferChange {
475 old_snapshot,
476 new_snapshot: new_snapshot.clone(),
477 timestamp: Instant::now(),
478 },
479 );
480 }
481
482 new_snapshot
483 }
484
    /// Appends `event` to the project's edit history.
    ///
    /// Coalescing requires a non-zero `buffer_change_grouping_interval`. The
    /// new change is merged into the last event when it continues that event:
    /// - It arrives within the interval.
    /// - It is for the same buffer.
    /// - It starts from the last event's new snapshot version.
    ///
    /// When merged, the last event's `new_snapshot` and `timestamp` are updated.
    /// Otherwise the event is appended, first dropping the oldest half of the
    /// history if it already holds `MAX_EVENT_COUNT` events.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
523
524 fn current_prediction_for_buffer(
525 &self,
526 buffer: &Entity<Buffer>,
527 project: &Entity<Project>,
528 cx: &App,
529 ) -> Option<BufferEditPrediction<'_>> {
530 let project_state = self.projects.get(&project.entity_id())?;
531
532 let CurrentEditPrediction {
533 requested_by_buffer_id,
534 prediction,
535 } = project_state.current_prediction.as_ref()?;
536
537 if prediction.targets_buffer(buffer.read(cx), cx) {
538 Some(BufferEditPrediction::Local { prediction })
539 } else if *requested_by_buffer_id == buffer.entity_id() {
540 Some(BufferEditPrediction::Jump { prediction })
541 } else {
542 None
543 }
544 }
545
    /// Accepts `project`'s current prediction and reports it to the server.
    ///
    /// The prediction is cleared immediately. Its id is then POSTed to
    /// `/predict_edits/accept` in a detached task; the URL can be overridden with
    /// `ZED_ACCEPT_PREDICTION_URL`. Errors are logged, not returned. Does
    /// nothing if the project is unknown or has no current prediction.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            // The response body is deserialized as `()`; only success/failure matters here.
            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
587
588 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
589 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
590 project_state.current_prediction.take();
591 };
592 }
593
594 pub fn refresh_prediction(
595 &mut self,
596 project: &Entity<Project>,
597 buffer: &Entity<Buffer>,
598 position: language::Anchor,
599 cx: &mut Context<Self>,
600 ) -> Task<Result<()>> {
601 let request_task = self.request_prediction(project, buffer, position, cx);
602 let buffer = buffer.clone();
603 let project = project.clone();
604
605 cx.spawn(async move |this, cx| {
606 if let Some(prediction) = request_task.await? {
607 this.update(cx, |this, cx| {
608 let project_state = this
609 .projects
610 .get_mut(&project.entity_id())
611 .context("Project not found")?;
612
613 let new_prediction = CurrentEditPrediction {
614 requested_by_buffer_id: buffer.entity_id(),
615 prediction: prediction,
616 };
617
618 if project_state
619 .current_prediction
620 .as_ref()
621 .is_none_or(|old_prediction| {
622 new_prediction.should_replace_prediction(&old_prediction, cx)
623 })
624 {
625 project_state.current_prediction = Some(new_prediction);
626 }
627 anyhow::Ok(())
628 })??;
629 }
630 Ok(())
631 })
632 }
633
    /// Requests a new edit prediction for `buffer` at `position`.
    ///
    /// On the foreground, gathers the request inputs:
    /// - the buffer snapshot and edit history
    /// - diagnostics and worktree snapshots
    /// - any previously retrieved context
    ///
    /// A background task then builds the request according to
    /// `options.context` and sends it. The task resolves to `Ok(None)` in two
    /// cases: no excerpt/context could be selected around the cursor, or
    /// `EditPrediction::from_response` yields nothing.
    ///
    /// # Errors
    /// - Fails immediately if the buffer has no file.
    /// - The task fails on request/response errors.
    /// - In debug builds, the task fails when `ZED_ZETA2_SKIP_REQUEST` is set.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle; it is locked on the background task.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Edit history as request events (empty for an unregistered project).
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        // Absolute path of the directory containing the buffer's file, if any.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Context from the last `refresh_context`, as (snapshot, full path, anchor
        // ranges) for each context buffer that has a file.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer, ranges)| {
                let buffer = buffer.read(cx);
                Some((
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = match options.context {
                    ContextMode::Llm(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
                            ..snapshot.anchor_before(excerpt.range.end);

                        // Add the cursor excerpt to the included files.
                        // - If this buffer is already in the retrieved context, insert
                        //   the range in sorted order (start ascending, then end
                        //   descending) and swap that file into the last position.
                        // - Otherwise, append the buffer as a new file.
                        if let Some(buffer_ix) = included_files
                            .iter()
                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
                        {
                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Merge each file's anchor ranges into line-based excerpts.
                        let included_files = included_files
                            .into_iter()
                            .map(|(buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path,
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        // The cursor excerpt travels in `included_files`, so the
                        // dedicated excerpt fields are left empty in this mode.
                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging:
                // - publish the request along with a locally built prompt;
                // - keep a oneshot sender to forward the eventual response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    let local_prompt = build_prompt(&request)
                        .map(|(prompt, _)| prompt)
                        .map_err(|err| err.to_string());

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
                            request: request.clone(),
                            retrieval_time,
                            buffer: buffer.downgrade(),
                            local_prompt,
                            position,
                            response_rx,
                        }))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug builds only: skip the network request entirely.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send(Err("Request skipped".to_string()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let response =
                    Self::send_prediction_request(client, llm_token, app_version, request).await;

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                        )
                        .ok();
                }

                response.map(|(res, usage)| (Some(res), usage))
            }
        });

        let buffer = buffer.clone();

        // Back on the foreground: apply usage/update-required side effects, then
        // turn the response into an `EditPrediction`.
        cx.spawn({
            let project = project.clone();
            async move |this, cx| {
                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
            }
        })
    }
898
899 async fn send_prediction_request(
900 client: Arc<Client>,
901 llm_token: LlmApiToken,
902 app_version: SemanticVersion,
903 request: predict_edits_v3::PredictEditsRequest,
904 ) -> Result<(
905 predict_edits_v3::PredictEditsResponse,
906 Option<EditPredictionUsage>,
907 )> {
908 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
909 http_client::Url::parse(&predict_edits_url)?
910 } else {
911 client
912 .http_client()
913 .build_zed_llm_url("/predict_edits/v3", &[])?
914 };
915
916 Self::send_api_request(
917 |builder| {
918 let req = builder
919 .uri(url.as_ref())
920 .body(serde_json::to_string(&request)?.into());
921 Ok(req?)
922 },
923 client,
924 llm_token,
925 app_version,
926 )
927 .await
928 }
929
930 fn handle_api_response<T>(
931 this: &WeakEntity<Self>,
932 response: Result<(T, Option<EditPredictionUsage>)>,
933 cx: &mut gpui::AsyncApp,
934 ) -> Result<T> {
935 match response {
936 Ok((data, usage)) => {
937 if let Some(usage) = usage {
938 this.update(cx, |this, cx| {
939 this.user_store.update(cx, |user_store, cx| {
940 user_store.update_edit_prediction_usage(usage, cx);
941 });
942 })
943 .ok();
944 }
945 Ok(data)
946 }
947 Err(err) => {
948 if err.is::<ZedUpdateRequiredError>() {
949 cx.update(|cx| {
950 this.update(cx, |this, _cx| {
951 this.update_required = true;
952 })
953 .ok();
954
955 let error_message: SharedString = err.to_string().into();
956 show_app_notification(
957 NotificationId::unique::<ZedUpdateRequiredError>(),
958 cx,
959 move |cx| {
960 cx.new(|cx| {
961 ErrorMessagePrompt::new(error_message.clone(), cx)
962 .with_link_button("Update Zed", "https://zed.dev/releases")
963 })
964 },
965 );
966 })
967 .ok();
968 }
969 Err(err)
970 }
971 }
972 }
973
974 async fn send_api_request<Res>(
975 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
976 client: Arc<Client>,
977 llm_token: LlmApiToken,
978 app_version: SemanticVersion,
979 ) -> Result<(Res, Option<EditPredictionUsage>)>
980 where
981 Res: DeserializeOwned,
982 {
983 let http_client = client.http_client();
984 let mut token = llm_token.acquire(&client).await?;
985 let mut did_retry = false;
986
987 loop {
988 let request_builder = http_client::Request::builder().method(Method::POST);
989
990 let request = build(
991 request_builder
992 .header("Content-Type", "application/json")
993 .header("Authorization", format!("Bearer {}", token))
994 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
995 )?;
996
997 let mut response = http_client.send(request).await?;
998
999 if let Some(minimum_required_version) = response
1000 .headers()
1001 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1002 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1003 {
1004 anyhow::ensure!(
1005 app_version >= minimum_required_version,
1006 ZedUpdateRequiredError {
1007 minimum_version: minimum_required_version
1008 }
1009 );
1010 }
1011
1012 if response.status().is_success() {
1013 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1014
1015 let mut body = Vec::new();
1016 response.body_mut().read_to_end(&mut body).await?;
1017 return Ok((serde_json::from_slice(&body)?, usage));
1018 } else if !did_retry
1019 && response
1020 .headers()
1021 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1022 .is_some()
1023 {
1024 did_retry = true;
1025 token = llm_token.refresh(&client).await?;
1026 } else {
1027 let mut body = String::new();
1028 response.body_mut().read_to_string(&mut body).await?;
1029 anyhow::bail!(
1030 "Request failed with status: {:?}\nBody: {}",
1031 response.status(),
1032 body
1033 );
1034 }
1035 }
1036 }
1037
    /// If `refresh_context_if_needed` was not called for longer than this, the
    /// next call refreshes context immediately instead of debouncing.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce delay before refreshing context while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1040
1041 // Refresh the related excerpts when the user just beguns editing after
1042 // an idle period, and after they pause editing.
1043 fn refresh_context_if_needed(
1044 &mut self,
1045 project: &Entity<Project>,
1046 buffer: &Entity<language::Buffer>,
1047 cursor_position: language::Anchor,
1048 cx: &mut Context<Self>,
1049 ) {
1050 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1051 return;
1052 }
1053
1054 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1055 return;
1056 };
1057
1058 let now = Instant::now();
1059 let was_idle = zeta_project
1060 .refresh_context_timestamp
1061 .map_or(true, |timestamp| {
1062 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1063 });
1064 zeta_project.refresh_context_timestamp = Some(now);
1065 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1066 let buffer = buffer.clone();
1067 let project = project.clone();
1068 async move |this, cx| {
1069 if was_idle {
1070 log::debug!("refetching edit prediction context after idle");
1071 } else {
1072 cx.background_executor()
1073 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1074 .await;
1075 log::debug!("refetching edit prediction context after pause");
1076 }
1077 this.update(cx, |this, cx| {
1078 this.refresh_context(project, buffer, cursor_position, cx);
1079 })
1080 .ok()
1081 }
1082 }));
1083 }
1084
1085 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1086 // and avoid spawning more than one concurrent task.
1087 fn refresh_context(
1088 &mut self,
1089 project: Entity<Project>,
1090 buffer: Entity<language::Buffer>,
1091 cursor_position: language::Anchor,
1092 cx: &mut Context<Self>,
1093 ) {
1094 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1095 return;
1096 };
1097
1098 let debug_tx = self.debug_tx.clone();
1099
1100 zeta_project
1101 .refresh_context_task
1102 .get_or_insert(cx.spawn(async move |this, cx| {
1103 let related_excerpts = this
1104 .update(cx, |this, cx| {
1105 let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1106 return Task::ready(anyhow::Ok(HashMap::default()));
1107 };
1108
1109 let ContextMode::Llm(options) = &this.options().context else {
1110 return Task::ready(anyhow::Ok(HashMap::default()));
1111 };
1112
1113 let mut edit_history_unified_diff = String::new();
1114
1115 for event in zeta_project.events.iter() {
1116 if let Some(event) = event.to_request_event(cx) {
1117 writeln!(&mut edit_history_unified_diff, "{event}").ok();
1118 }
1119 }
1120
1121 find_related_excerpts(
1122 buffer.clone(),
1123 cursor_position,
1124 &project,
1125 edit_history_unified_diff,
1126 options,
1127 debug_tx,
1128 cx,
1129 )
1130 })
1131 .ok()?
1132 .await
1133 .log_err()
1134 .unwrap_or_default();
1135 this.update(cx, |this, _cx| {
1136 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1137 return;
1138 };
1139 zeta_project.context = Some(related_excerpts);
1140 zeta_project.refresh_context_task.take();
1141 if let Some(debug_tx) = &this.debug_tx {
1142 debug_tx
1143 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1144 ZetaContextRetrievalDebugInfo {
1145 project,
1146 timestamp: Instant::now(),
1147 },
1148 ))
1149 .ok();
1150 }
1151 })
1152 .ok()
1153 }));
1154 }
1155
1156 fn gather_nearby_diagnostics(
1157 cursor_offset: usize,
1158 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1159 snapshot: &BufferSnapshot,
1160 max_diagnostics_bytes: usize,
1161 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1162 // TODO: Could make this more efficient
1163 let mut diagnostic_groups = Vec::new();
1164 for (language_server_id, diagnostics) in diagnostic_sets {
1165 let mut groups = Vec::new();
1166 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1167 diagnostic_groups.extend(
1168 groups
1169 .into_iter()
1170 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1171 );
1172 }
1173
1174 // sort by proximity to cursor
1175 diagnostic_groups.sort_by_key(|group| {
1176 let range = &group.entries[group.primary_ix].range;
1177 if range.start >= cursor_offset {
1178 range.start - cursor_offset
1179 } else if cursor_offset >= range.end {
1180 cursor_offset - range.end
1181 } else {
1182 (cursor_offset - range.start).min(range.end - cursor_offset)
1183 }
1184 });
1185
1186 let mut results = Vec::new();
1187 let mut diagnostic_groups_truncated = false;
1188 let mut diagnostics_byte_count = 0;
1189 for group in diagnostic_groups {
1190 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1191 diagnostics_byte_count += raw_value.get().len();
1192 if diagnostics_byte_count > max_diagnostics_bytes {
1193 diagnostic_groups_truncated = true;
1194 break;
1195 }
1196 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1197 }
1198
1199 (results, diagnostic_groups_truncated)
1200 }
1201
1202 // TODO: Dedupe with similar code in request_prediction?
1203 pub fn cloud_request_for_zeta_cli(
1204 &mut self,
1205 project: &Entity<Project>,
1206 buffer: &Entity<Buffer>,
1207 position: language::Anchor,
1208 cx: &mut Context<Self>,
1209 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1210 let project_state = self.projects.get(&project.entity_id());
1211
1212 let index_state = project_state.map(|state| {
1213 state
1214 .syntax_index
1215 .read_with(cx, |index, _cx| index.state().clone())
1216 });
1217 let options = self.options.clone();
1218 let snapshot = buffer.read(cx).snapshot();
1219 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1220 return Task::ready(Err(anyhow!("No file path for excerpt")));
1221 };
1222 let worktree_snapshots = project
1223 .read(cx)
1224 .worktrees(cx)
1225 .map(|worktree| worktree.read(cx).snapshot())
1226 .collect::<Vec<_>>();
1227
1228 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1229 let mut path = f.worktree.read(cx).absolutize(&f.path);
1230 if path.pop() { Some(path) } else { None }
1231 });
1232
1233 cx.background_spawn(async move {
1234 let index_state = if let Some(index_state) = index_state {
1235 Some(index_state.lock_owned().await)
1236 } else {
1237 None
1238 };
1239
1240 let cursor_point = position.to_point(&snapshot);
1241
1242 let debug_info = true;
1243 EditPredictionContext::gather_context(
1244 cursor_point,
1245 &snapshot,
1246 parent_abs_path.as_deref(),
1247 match &options.context {
1248 ContextMode::Llm(_) => {
1249 // TODO
1250 panic!("Llm mode not supported in zeta cli yet");
1251 }
1252 ContextMode::Syntax(edit_prediction_context_options) => {
1253 edit_prediction_context_options
1254 }
1255 },
1256 index_state.as_deref(),
1257 )
1258 .context("Failed to select excerpt")
1259 .map(|context| {
1260 make_syntax_context_cloud_request(
1261 excerpt_path.into(),
1262 context,
1263 // TODO pass everything
1264 Vec::new(),
1265 false,
1266 Vec::new(),
1267 false,
1268 None,
1269 debug_info,
1270 &worktree_snapshots,
1271 index_state.as_deref(),
1272 Some(options.max_prompt_bytes),
1273 options.prompt_format,
1274 )
1275 })
1276 })
1277 }
1278
1279 pub fn wait_for_initial_indexing(
1280 &mut self,
1281 project: &Entity<Project>,
1282 cx: &mut App,
1283 ) -> Task<Result<()>> {
1284 let zeta_project = self.get_or_init_zeta_project(project, cx);
1285 zeta_project
1286 .syntax_index
1287 .read(cx)
1288 .wait_for_initial_file_indexing(cx)
1289 }
1290}
1291
1292#[derive(Error, Debug)]
1293#[error(
1294 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1295)]
1296pub struct ZedUpdateRequiredError {
1297 minimum_version: SemanticVersion,
1298}
1299
1300fn make_syntax_context_cloud_request(
1301 excerpt_path: Arc<Path>,
1302 context: EditPredictionContext,
1303 events: Vec<predict_edits_v3::Event>,
1304 can_collect_data: bool,
1305 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1306 diagnostic_groups_truncated: bool,
1307 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1308 debug_info: bool,
1309 worktrees: &Vec<worktree::Snapshot>,
1310 index_state: Option<&SyntaxIndexState>,
1311 prompt_max_bytes: Option<usize>,
1312 prompt_format: PromptFormat,
1313) -> predict_edits_v3::PredictEditsRequest {
1314 let mut signatures = Vec::new();
1315 let mut declaration_to_signature_index = HashMap::default();
1316 let mut referenced_declarations = Vec::new();
1317
1318 for snippet in context.declarations {
1319 let project_entry_id = snippet.declaration.project_entry_id();
1320 let Some(path) = worktrees.iter().find_map(|worktree| {
1321 worktree.entry_for_id(project_entry_id).map(|entry| {
1322 let mut full_path = RelPathBuf::new();
1323 full_path.push(worktree.root_name());
1324 full_path.push(&entry.path);
1325 full_path
1326 })
1327 }) else {
1328 continue;
1329 };
1330
1331 let parent_index = index_state.and_then(|index_state| {
1332 snippet.declaration.parent().and_then(|parent| {
1333 add_signature(
1334 parent,
1335 &mut declaration_to_signature_index,
1336 &mut signatures,
1337 index_state,
1338 )
1339 })
1340 });
1341
1342 let (text, text_is_truncated) = snippet.declaration.item_text();
1343 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1344 path: path.as_std_path().into(),
1345 text: text.into(),
1346 range: snippet.declaration.item_line_range(),
1347 text_is_truncated,
1348 signature_range: snippet.declaration.signature_range_in_item_text(),
1349 parent_index,
1350 signature_score: snippet.score(DeclarationStyle::Signature),
1351 declaration_score: snippet.score(DeclarationStyle::Declaration),
1352 score_components: snippet.components,
1353 });
1354 }
1355
1356 let excerpt_parent = index_state.and_then(|index_state| {
1357 context
1358 .excerpt
1359 .parent_declarations
1360 .last()
1361 .and_then(|(parent, _)| {
1362 add_signature(
1363 *parent,
1364 &mut declaration_to_signature_index,
1365 &mut signatures,
1366 index_state,
1367 )
1368 })
1369 });
1370
1371 predict_edits_v3::PredictEditsRequest {
1372 excerpt_path,
1373 excerpt: context.excerpt_text.body,
1374 excerpt_line_range: context.excerpt.line_range,
1375 excerpt_range: context.excerpt.range,
1376 cursor_point: predict_edits_v3::Point {
1377 line: predict_edits_v3::Line(context.cursor_point.row),
1378 column: context.cursor_point.column,
1379 },
1380 referenced_declarations,
1381 included_files: vec![],
1382 signatures,
1383 excerpt_parent,
1384 events,
1385 can_collect_data,
1386 diagnostic_groups,
1387 diagnostic_groups_truncated,
1388 git_info,
1389 debug_info,
1390 prompt_max_bytes,
1391 prompt_format,
1392 }
1393}
1394
1395fn add_signature(
1396 declaration_id: DeclarationId,
1397 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1398 signatures: &mut Vec<Signature>,
1399 index: &SyntaxIndexState,
1400) -> Option<usize> {
1401 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1402 return Some(*signature_index);
1403 }
1404 let Some(parent_declaration) = index.declaration(declaration_id) else {
1405 log::error!("bug: missing parent declaration");
1406 return None;
1407 };
1408 let parent_index = parent_declaration.parent().and_then(|parent| {
1409 add_signature(parent, declaration_to_signature_index, signatures, index)
1410 });
1411 let (text, text_is_truncated) = parent_declaration.signature_text();
1412 let signature_index = signatures.len();
1413 signatures.push(Signature {
1414 text: text.into(),
1415 text_is_truncated,
1416 parent_index,
1417 range: parent_declaration.signature_line_range(),
1418 });
1419 declaration_to_signature_index.insert(declaration_id, signature_index);
1420 Some(signature_index)
1421}
1422
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// A prediction that edits the current buffer is reported as `Local`.
    /// A prediction that edits another file is reported as a `Jump` from the
    /// current buffer, and as `Local` once that file is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Use `path!` so the fake tree matches the project root on every platform.
        fs.insert_tree(
            path!("/root"),
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// The request carries the excerpt path and cursor point. A response that
    /// rewrites the whole file is reduced to the minimal inserted text at the
    /// cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// An edit to a registered buffer is sent in the request as a
    /// `BufferChange` event containing a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Diagnostics published for the buffer are serialized into the request's
    /// `diagnostic_groups`.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Initializes settings, language support and a global `Zeta` backed by a
    /// fake HTTP client.
    ///
    /// The fake client handles two paths:
    /// - `/client/llm_tokens` is answered with a static token.
    /// - Each `/predict_edits/v3` request is decoded and sent on the returned
    ///   receiver, paired with a oneshot sender the test uses to supply the
    ///   response.
    ///
    /// Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}