1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::fmt::Write;
32use std::ops::Range;
33use std::path::Path;
34use std::str::FromStr as _;
35use std::sync::Arc;
36use std::time::{Duration, Instant};
37use thiserror::Error;
38use util::rel_path::RelPathBuf;
39use util::{LogErrorFuture, TryFutureExt};
40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
41
42pub mod merge_excerpts;
43mod prediction;
44mod provider;
45pub mod related_excerpts;
46
47use crate::merge_excerpts::merge_excerpts;
48use crate::prediction::EditPrediction;
49use crate::related_excerpts::find_related_excerpts;
50pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
51pub use provider::ZetaEditPredictionProvider;
52
/// Upper bound on the number of edit events retained per project.
///
/// When the history reaches this size, the oldest half is discarded in one go
/// rather than popping a single event (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;
55
56pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
57 max_bytes: 512,
58 min_bytes: 128,
59 target_before_cursor_over_total_bytes: 0.5,
60};
61
62pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
63
64pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
65 excerpt: DEFAULT_EXCERPT_OPTIONS,
66};
67
68pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
69 EditPredictionContextOptions {
70 use_imports: true,
71 max_retrieved_declarations: 0,
72 excerpt: DEFAULT_EXCERPT_OPTIONS,
73 score: EditPredictionScoreOptions {
74 omit_excerpt_overlaps: true,
75 },
76 };
77
78pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
79 context: DEFAULT_CONTEXT_OPTIONS,
80 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
81 max_diagnostic_bytes: 2048,
82 prompt_format: PromptFormat::DEFAULT,
83 file_indexing_parallelism: 1,
84 buffer_change_grouping_interval: Duration::from_secs(1),
85};
86
87pub struct Zeta2FeatureFlag;
88
89impl FeatureFlag for Zeta2FeatureFlag {
90 const NAME: &'static str = "zeta2";
91
92 fn enabled_for_staff() -> bool {
93 false
94 }
95}
96
/// Global slot holding the app-wide [`Zeta`] entity (installed by `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
101
/// Edit prediction engine state. Can be shared app-wide via `Zeta::global`.
pub struct Zeta {
    /// Client used to reach the LLM endpoints.
    client: Arc<Client>,
    /// Receives edit prediction usage reported in API response headers.
    user_store: Entity<UserStore>,
    /// Token sent as the `Authorization` bearer on API requests.
    llm_token: LlmApiToken,
    /// Refreshes `llm_token` whenever the global `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Active configuration; starts as [`DEFAULT_OPTIONS`].
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is below its minimum.
    update_required: bool,
    /// When installed via `debug_info`, receives [`ZetaDebugInfo`] events.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
112
/// Tunable configuration for [`Zeta`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for a prediction request is gathered.
    pub context: ContextMode,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the diagnostic groups near the cursor that are included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format set on outgoing requests (and used for the local debug prompt).
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first seen.
    pub file_indexing_parallelism: usize,
    /// Consecutive edits to the same buffer that arrive within this interval are
    /// merged into one history event; `Duration::ZERO` disables merging.
    pub buffer_change_grouping_interval: Duration,
}
122
123#[derive(Debug, Clone, PartialEq)]
124pub enum ContextMode {
125 Llm(LlmContextOptions),
126 Syntax(EditPredictionContextOptions),
127}
128
129impl ContextMode {
130 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
131 match self {
132 ContextMode::Llm(options) => &options.excerpt,
133 ContextMode::Syntax(options) => &options.excerpt,
134 }
135 }
136}
137
/// Debugging events sent on the channel returned by `Zeta::debug_info`.
///
/// NOTE(review): the search-related variants are presumably emitted by
/// `find_related_excerpts` (it receives the debug sender); not visible here.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    /// Context retrieval started for a project.
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    /// Search queries were generated during context retrieval.
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    /// Search queries were executed during context retrieval.
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    /// Search results were filtered during context retrieval.
    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
    /// Context retrieval finished; sent by `Zeta::refresh_context`.
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// A prediction request was built and is about to be sent; sent by
    /// `Zeta::request_prediction`.
    EditPredicted(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Prompt used for the search step.
    pub search_prompt: String,
}

/// Payload for retrieval progress events: which project, and when.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredicted`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    /// The request as it will be sent to the server.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics and context before the request was built.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally with `build_prompt`, or the rendering error.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response, or an error string (including when
    /// the request was skipped).
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub queries: Vec<SearchToolQuery>,
}

/// Server-side debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
179
/// Per-project edit prediction state.
struct ZetaProject {
    /// Syntax index for the project, created with `file_indexing_parallelism`.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT` in `Zeta::push_event`.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by buffer entity id; entries are removed on buffer release.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts per buffer from the last completed context retrieval.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context retrieval; the task clears this slot itself when it completes.
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    /// Pending (possibly debounced) trigger for the next context retrieval.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When `Zeta::refresh_context_if_needed` last ran; used to detect idleness.
    refresh_context_timestamp: Option<Instant>,
}
190
191#[derive(Debug, Clone)]
192struct CurrentEditPrediction {
193 pub requested_by_buffer_id: EntityId,
194 pub prediction: EditPrediction,
195}
196
197impl CurrentEditPrediction {
198 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
199 let Some(new_edits) = self
200 .prediction
201 .interpolate(&self.prediction.buffer.read(cx))
202 else {
203 return false;
204 };
205
206 if self.prediction.buffer != old_prediction.prediction.buffer {
207 return true;
208 }
209
210 let Some(old_edits) = old_prediction
211 .prediction
212 .interpolate(&old_prediction.prediction.buffer.read(cx))
213 else {
214 return true;
215 };
216
217 // This reduces the occurrence of UI thrash from replacing edits
218 //
219 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
220 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
221 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
222 && old_edits.len() == 1
223 && new_edits.len() == 1
224 {
225 let (old_range, old_text) = &old_edits[0];
226 let (new_range, new_text) = &new_edits[0];
227 new_range == old_range && new_text.starts_with(old_text)
228 } else {
229 true
230 }
231 }
232}
233
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction's edits target this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer but was requested from this one
    /// (presumably so the UI can offer to jump there — confirm with callers).
    Jump { prediction: &'a EditPrediction },
}

/// A buffer whose edits are tracked for the project's edit history.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's `Edited` events and to its release.
    _subscriptions: [gpui::Subscription; 2],
}
245
246#[derive(Clone)]
247pub enum Event {
248 BufferChange {
249 old_snapshot: BufferSnapshot,
250 new_snapshot: BufferSnapshot,
251 timestamp: Instant,
252 },
253}
254
255impl Event {
256 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
257 match self {
258 Event::BufferChange {
259 old_snapshot,
260 new_snapshot,
261 ..
262 } => {
263 let path = new_snapshot.file().map(|f| f.full_path(cx));
264
265 let old_path = old_snapshot.file().and_then(|f| {
266 let old_path = f.full_path(cx);
267 if Some(&old_path) != path.as_ref() {
268 Some(old_path)
269 } else {
270 None
271 }
272 });
273
274 // TODO [zeta2] move to bg?
275 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
276
277 if path == old_path && diff.is_empty() {
278 None
279 } else {
280 Some(predict_edits_v3::Event::BufferChange {
281 old_path,
282 path,
283 diff,
284 //todo: Actually detect if this edit was predicted or not
285 predicted: false,
286 })
287 }
288 }
289 }
290 }
291}
292
293impl Zeta {
294 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
295 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
296 }
297
298 pub fn global(
299 client: &Arc<Client>,
300 user_store: &Entity<UserStore>,
301 cx: &mut App,
302 ) -> Entity<Self> {
303 cx.try_global::<ZetaGlobal>()
304 .map(|global| global.0.clone())
305 .unwrap_or_else(|| {
306 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
307 cx.set_global(ZetaGlobal(zeta.clone()));
308 zeta
309 })
310 }
311
312 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
313 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
314
315 Self {
316 projects: HashMap::default(),
317 client,
318 user_store,
319 options: DEFAULT_OPTIONS,
320 llm_token: LlmApiToken::default(),
321 _llm_token_subscription: cx.subscribe(
322 &refresh_llm_token_listener,
323 |this, _listener, _event, cx| {
324 let client = this.client.clone();
325 let llm_token = this.llm_token.clone();
326 cx.spawn(async move |_this, _cx| {
327 llm_token.refresh(&client).await?;
328 anyhow::Ok(())
329 })
330 .detach_and_log_err(cx);
331 },
332 ),
333 update_required: false,
334 debug_tx: None,
335 }
336 }
337
338 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
339 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
340 self.debug_tx = Some(debug_watch_tx);
341 debug_watch_rx
342 }
343
344 pub fn options(&self) -> &ZetaOptions {
345 &self.options
346 }
347
348 pub fn set_options(&mut self, options: ZetaOptions) {
349 self.options = options;
350 }
351
352 pub fn clear_history(&mut self) {
353 for zeta_project in self.projects.values_mut() {
354 zeta_project.events.clear();
355 }
356 }
357
358 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
359 self.projects
360 .get(&project.entity_id())
361 .map(|project| project.events.iter())
362 .into_iter()
363 .flatten()
364 }
365
366 pub fn context_for_project(
367 &self,
368 project: &Entity<Project>,
369 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
370 self.projects
371 .get(&project.entity_id())
372 .and_then(|project| {
373 Some(
374 project
375 .context
376 .as_ref()?
377 .iter()
378 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
379 )
380 })
381 .into_iter()
382 .flatten()
383 }
384
385 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
386 self.user_store.read(cx).edit_prediction_usage()
387 }
388
389 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
390 self.get_or_init_zeta_project(project, cx);
391 }
392
393 pub fn register_buffer(
394 &mut self,
395 buffer: &Entity<Buffer>,
396 project: &Entity<Project>,
397 cx: &mut Context<Self>,
398 ) {
399 let zeta_project = self.get_or_init_zeta_project(project, cx);
400 Self::register_buffer_impl(zeta_project, buffer, project, cx);
401 }
402
403 fn get_or_init_zeta_project(
404 &mut self,
405 project: &Entity<Project>,
406 cx: &mut App,
407 ) -> &mut ZetaProject {
408 self.projects
409 .entry(project.entity_id())
410 .or_insert_with(|| ZetaProject {
411 syntax_index: cx.new(|cx| {
412 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
413 }),
414 events: VecDeque::new(),
415 registered_buffers: HashMap::default(),
416 current_prediction: None,
417 context: None,
418 refresh_context_task: None,
419 refresh_context_debounce_task: None,
420 refresh_context_timestamp: None,
421 })
422 }
423
    /// Returns the registration for `buffer` within `zeta_project`, creating it on first use.
    ///
    /// A new registration records the buffer's current snapshot and subscribes to:
    /// - `BufferEvent::Edited`, which calls `report_changes_for_buffer` while the
    ///   project is still alive (held weakly to avoid keeping it alive);
    /// - buffer release, which removes the registration from the project.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                // Captured by id so the release callback can look the project up again later.
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
461
462 fn report_changes_for_buffer(
463 &mut self,
464 buffer: &Entity<Buffer>,
465 project: &Entity<Project>,
466 cx: &mut Context<Self>,
467 ) -> BufferSnapshot {
468 let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
469 let zeta_project = self.get_or_init_zeta_project(project, cx);
470 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
471
472 let new_snapshot = buffer.read(cx).snapshot();
473 if new_snapshot.version != registered_buffer.snapshot.version {
474 let old_snapshot =
475 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
476 Self::push_event(
477 zeta_project,
478 buffer_change_grouping_interval,
479 Event::BufferChange {
480 old_snapshot,
481 new_snapshot: new_snapshot.clone(),
482 timestamp: Instant::now(),
483 },
484 );
485 }
486
487 new_snapshot
488 }
489
490 fn push_event(
491 zeta_project: &mut ZetaProject,
492 buffer_change_grouping_interval: Duration,
493 event: Event,
494 ) {
495 let events = &mut zeta_project.events;
496
497 if buffer_change_grouping_interval > Duration::ZERO
498 && let Some(Event::BufferChange {
499 new_snapshot: last_new_snapshot,
500 timestamp: last_timestamp,
501 ..
502 }) = events.back_mut()
503 {
504 // Coalesce edits for the same buffer when they happen one after the other.
505 let Event::BufferChange {
506 old_snapshot,
507 new_snapshot,
508 timestamp,
509 } = &event;
510
511 if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
512 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
513 && old_snapshot.version == last_new_snapshot.version
514 {
515 *last_new_snapshot = new_snapshot.clone();
516 *last_timestamp = *timestamp;
517 return;
518 }
519 }
520
521 if events.len() >= MAX_EVENT_COUNT {
522 // These are halved instead of popping to improve prompt caching.
523 events.drain(..MAX_EVENT_COUNT / 2);
524 }
525
526 events.push_back(event);
527 }
528
529 fn current_prediction_for_buffer(
530 &self,
531 buffer: &Entity<Buffer>,
532 project: &Entity<Project>,
533 cx: &App,
534 ) -> Option<BufferEditPrediction<'_>> {
535 let project_state = self.projects.get(&project.entity_id())?;
536
537 let CurrentEditPrediction {
538 requested_by_buffer_id,
539 prediction,
540 } = project_state.current_prediction.as_ref()?;
541
542 if prediction.targets_buffer(buffer.read(cx), cx) {
543 Some(BufferEditPrediction::Local { prediction })
544 } else if *requested_by_buffer_id == buffer.entity_id() {
545 Some(BufferEditPrediction::Jump { prediction })
546 } else {
547 None
548 }
549 }
550
    /// Accepts the project's current prediction.
    ///
    /// The prediction is cleared locally right away, and the acceptance is reported to
    /// `/predict_edits/accept` (overridable via `ZED_ACCEPT_PREDICTION_URL`) in a detached
    /// task whose errors are logged rather than returned. Does nothing if the project is
    /// unknown or has no current prediction.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    // May be invoked twice if the token has to be refreshed.
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
592
593 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
594 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
595 project_state.current_prediction.take();
596 };
597 }
598
599 pub fn refresh_prediction(
600 &mut self,
601 project: &Entity<Project>,
602 buffer: &Entity<Buffer>,
603 position: language::Anchor,
604 cx: &mut Context<Self>,
605 ) -> Task<Result<()>> {
606 let request_task = self.request_prediction(project, buffer, position, cx);
607 let buffer = buffer.clone();
608 let project = project.clone();
609
610 cx.spawn(async move |this, cx| {
611 if let Some(prediction) = request_task.await? {
612 this.update(cx, |this, cx| {
613 let project_state = this
614 .projects
615 .get_mut(&project.entity_id())
616 .context("Project not found")?;
617
618 let new_prediction = CurrentEditPrediction {
619 requested_by_buffer_id: buffer.entity_id(),
620 prediction: prediction,
621 };
622
623 if project_state
624 .current_prediction
625 .as_ref()
626 .is_none_or(|old_prediction| {
627 new_prediction.should_replace_prediction(&old_prediction, cx)
628 })
629 {
630 project_state.current_prediction = Some(new_prediction);
631 }
632 anyhow::Ok(())
633 })??;
634 }
635 Ok(())
636 })
637 }
638
639 pub fn request_prediction(
640 &mut self,
641 project: &Entity<Project>,
642 buffer: &Entity<Buffer>,
643 position: language::Anchor,
644 cx: &mut Context<Self>,
645 ) -> Task<Result<Option<EditPrediction>>> {
646 let project_state = self.projects.get(&project.entity_id());
647
648 let index_state = project_state.map(|state| {
649 state
650 .syntax_index
651 .read_with(cx, |index, _cx| index.state().clone())
652 });
653 let options = self.options.clone();
654 let snapshot = buffer.read(cx).snapshot();
655 let Some(excerpt_path) = snapshot
656 .file()
657 .map(|path| -> Arc<Path> { path.full_path(cx).into() })
658 else {
659 return Task::ready(Err(anyhow!("No file path for excerpt")));
660 };
661 let client = self.client.clone();
662 let llm_token = self.llm_token.clone();
663 let app_version = AppVersion::global(cx);
664 let worktree_snapshots = project
665 .read(cx)
666 .worktrees(cx)
667 .map(|worktree| worktree.read(cx).snapshot())
668 .collect::<Vec<_>>();
669 let debug_tx = self.debug_tx.clone();
670
671 let events = project_state
672 .map(|state| {
673 state
674 .events
675 .iter()
676 .filter_map(|event| event.to_request_event(cx))
677 .collect::<Vec<_>>()
678 })
679 .unwrap_or_default();
680
681 let diagnostics = snapshot.diagnostic_sets().clone();
682
683 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
684 let mut path = f.worktree.read(cx).absolutize(&f.path);
685 if path.pop() { Some(path) } else { None }
686 });
687
688 // TODO data collection
689 let can_collect_data = cx.is_staff();
690
691 let mut included_files = project_state
692 .and_then(|project_state| project_state.context.as_ref())
693 .unwrap_or(&HashMap::default())
694 .iter()
695 .filter_map(|(buffer, ranges)| {
696 let buffer = buffer.read(cx);
697 Some((
698 buffer.snapshot(),
699 buffer.file()?.full_path(cx).into(),
700 ranges.clone(),
701 ))
702 })
703 .collect::<Vec<_>>();
704
705 let request_task = cx.background_spawn({
706 let snapshot = snapshot.clone();
707 let buffer = buffer.clone();
708 async move {
709 let index_state = if let Some(index_state) = index_state {
710 Some(index_state.lock_owned().await)
711 } else {
712 None
713 };
714
715 let cursor_offset = position.to_offset(&snapshot);
716 let cursor_point = cursor_offset.to_point(&snapshot);
717
718 let before_retrieval = chrono::Utc::now();
719
720 let (diagnostic_groups, diagnostic_groups_truncated) =
721 Self::gather_nearby_diagnostics(
722 cursor_offset,
723 &diagnostics,
724 &snapshot,
725 options.max_diagnostic_bytes,
726 );
727
728 let request = match options.context {
729 ContextMode::Llm(context_options) => {
730 let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
731 cursor_point,
732 &snapshot,
733 &context_options.excerpt,
734 index_state.as_deref(),
735 ) else {
736 return Ok((None, None));
737 };
738
739 let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
740 ..snapshot.anchor_before(excerpt.range.end);
741
742 if let Some(buffer_ix) = included_files
743 .iter()
744 .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
745 {
746 let (buffer, _, ranges) = &mut included_files[buffer_ix];
747 let range_ix = ranges
748 .binary_search_by(|probe| {
749 probe
750 .start
751 .cmp(&excerpt_anchor_range.start, buffer)
752 .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
753 })
754 .unwrap_or_else(|ix| ix);
755
756 ranges.insert(range_ix, excerpt_anchor_range);
757 let last_ix = included_files.len() - 1;
758 included_files.swap(buffer_ix, last_ix);
759 } else {
760 included_files.push((
761 snapshot,
762 excerpt_path.clone(),
763 vec![excerpt_anchor_range],
764 ));
765 }
766
767 let included_files = included_files
768 .into_iter()
769 .map(|(buffer, path, ranges)| {
770 let excerpts = merge_excerpts(
771 &buffer,
772 ranges.iter().map(|range| {
773 let point_range = range.to_point(&buffer);
774 Line(point_range.start.row)..Line(point_range.end.row)
775 }),
776 );
777 predict_edits_v3::IncludedFile {
778 path,
779 max_row: Line(buffer.max_point().row),
780 excerpts,
781 }
782 })
783 .collect::<Vec<_>>();
784
785 predict_edits_v3::PredictEditsRequest {
786 excerpt_path,
787 excerpt: String::new(),
788 excerpt_line_range: Line(0)..Line(0),
789 excerpt_range: 0..0,
790 cursor_point: predict_edits_v3::Point {
791 line: predict_edits_v3::Line(cursor_point.row),
792 column: cursor_point.column,
793 },
794 included_files,
795 referenced_declarations: vec![],
796 events,
797 can_collect_data,
798 diagnostic_groups,
799 diagnostic_groups_truncated,
800 debug_info: debug_tx.is_some(),
801 prompt_max_bytes: Some(options.max_prompt_bytes),
802 prompt_format: options.prompt_format,
803 // TODO [zeta2]
804 signatures: vec![],
805 excerpt_parent: None,
806 git_info: None,
807 }
808 }
809 ContextMode::Syntax(context_options) => {
810 let Some(context) = EditPredictionContext::gather_context(
811 cursor_point,
812 &snapshot,
813 parent_abs_path.as_deref(),
814 &context_options,
815 index_state.as_deref(),
816 ) else {
817 return Ok((None, None));
818 };
819
820 make_syntax_context_cloud_request(
821 excerpt_path,
822 context,
823 events,
824 can_collect_data,
825 diagnostic_groups,
826 diagnostic_groups_truncated,
827 None,
828 debug_tx.is_some(),
829 &worktree_snapshots,
830 index_state.as_deref(),
831 Some(options.max_prompt_bytes),
832 options.prompt_format,
833 )
834 }
835 };
836
837 let retrieval_time = chrono::Utc::now() - before_retrieval;
838
839 let debug_response_tx = if let Some(debug_tx) = &debug_tx {
840 let (response_tx, response_rx) = oneshot::channel();
841
842 let local_prompt = build_prompt(&request)
843 .map(|(prompt, _)| prompt)
844 .map_err(|err| err.to_string());
845
846 debug_tx
847 .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
848 request: request.clone(),
849 retrieval_time,
850 buffer: buffer.downgrade(),
851 local_prompt,
852 position,
853 response_rx,
854 }))
855 .ok();
856 Some(response_tx)
857 } else {
858 None
859 };
860
861 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
862 if let Some(debug_response_tx) = debug_response_tx {
863 debug_response_tx
864 .send(Err("Request skipped".to_string()))
865 .ok();
866 }
867 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
868 }
869
870 let response =
871 Self::send_prediction_request(client, llm_token, app_version, request).await;
872
873 if let Some(debug_response_tx) = debug_response_tx {
874 debug_response_tx
875 .send(
876 response
877 .as_ref()
878 .map_err(|err| err.to_string())
879 .map(|response| response.0.clone()),
880 )
881 .ok();
882 }
883
884 response.map(|(res, usage)| (Some(res), usage))
885 }
886 });
887
888 let buffer = buffer.clone();
889
890 cx.spawn({
891 let project = project.clone();
892 async move |this, cx| {
893 let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
894 else {
895 return Ok(None);
896 };
897
898 // TODO telemetry: duration, etc
899 Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
900 }
901 })
902 }
903
904 async fn send_prediction_request(
905 client: Arc<Client>,
906 llm_token: LlmApiToken,
907 app_version: SemanticVersion,
908 request: predict_edits_v3::PredictEditsRequest,
909 ) -> Result<(
910 predict_edits_v3::PredictEditsResponse,
911 Option<EditPredictionUsage>,
912 )> {
913 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
914 http_client::Url::parse(&predict_edits_url)?
915 } else {
916 client
917 .http_client()
918 .build_zed_llm_url("/predict_edits/v3", &[])?
919 };
920
921 Self::send_api_request(
922 |builder| {
923 let req = builder
924 .uri(url.as_ref())
925 .body(serde_json::to_string(&request)?.into());
926 Ok(req?)
927 },
928 client,
929 llm_token,
930 app_version,
931 )
932 .await
933 }
934
935 fn handle_api_response<T>(
936 this: &WeakEntity<Self>,
937 response: Result<(T, Option<EditPredictionUsage>)>,
938 cx: &mut gpui::AsyncApp,
939 ) -> Result<T> {
940 match response {
941 Ok((data, usage)) => {
942 if let Some(usage) = usage {
943 this.update(cx, |this, cx| {
944 this.user_store.update(cx, |user_store, cx| {
945 user_store.update_edit_prediction_usage(usage, cx);
946 });
947 })
948 .ok();
949 }
950 Ok(data)
951 }
952 Err(err) => {
953 if err.is::<ZedUpdateRequiredError>() {
954 cx.update(|cx| {
955 this.update(cx, |this, _cx| {
956 this.update_required = true;
957 })
958 .ok();
959
960 let error_message: SharedString = err.to_string().into();
961 show_app_notification(
962 NotificationId::unique::<ZedUpdateRequiredError>(),
963 cx,
964 move |cx| {
965 cx.new(|cx| {
966 ErrorMessagePrompt::new(error_message.clone(), cx)
967 .with_link_button("Update Zed", "https://zed.dev/releases")
968 })
969 },
970 );
971 })
972 .ok();
973 }
974 Err(err)
975 }
976 }
977 }
978
    /// Sends an authenticated JSON `POST` built by `build` and deserializes the body.
    ///
    /// `build` receives a builder that already carries the `Content-Type`,
    /// `Authorization: Bearer` and Zed version headers. It is `Fn` because it is invoked
    /// again if the request is retried. Returns the parsed body plus any usage parsed
    /// from the response headers.
    ///
    /// # Errors
    ///
    /// - [`ZedUpdateRequiredError`] if the response advertises a minimum required
    ///   version above `app_version` (checked before the status code).
    /// - An error containing the status and body for non-success responses, after at
    ///   most one retry with a refreshed token when the server flags the token as expired.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // An unparseable version header is ignored rather than treated as an error.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Retry exactly once with a fresh token.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1042
    /// Gap between calls to `refresh_context_if_needed` after which the user counts as
    /// idle; the first call after such a gap refreshes context immediately.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Delay after the most recent call before context is refreshed when not idle.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1045
1046 // Refresh the related excerpts when the user just beguns editing after
1047 // an idle period, and after they pause editing.
1048 fn refresh_context_if_needed(
1049 &mut self,
1050 project: &Entity<Project>,
1051 buffer: &Entity<language::Buffer>,
1052 cursor_position: language::Anchor,
1053 cx: &mut Context<Self>,
1054 ) {
1055 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1056 return;
1057 }
1058
1059 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1060 return;
1061 };
1062
1063 let now = Instant::now();
1064 let was_idle = zeta_project
1065 .refresh_context_timestamp
1066 .map_or(true, |timestamp| {
1067 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1068 });
1069 zeta_project.refresh_context_timestamp = Some(now);
1070 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1071 let buffer = buffer.clone();
1072 let project = project.clone();
1073 async move |this, cx| {
1074 if was_idle {
1075 log::debug!("refetching edit prediction context after idle");
1076 } else {
1077 cx.background_executor()
1078 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1079 .await;
1080 log::debug!("refetching edit prediction context after pause");
1081 }
1082 this.update(cx, |this, cx| {
1083 let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1084
1085 if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1086 zeta_project.refresh_context_task = Some(task.log_err());
1087 };
1088 })
1089 .ok()
1090 }
1091 }));
1092 }
1093
1094 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1095 // and avoid spawning more than one concurrent task.
1096 pub fn refresh_context(
1097 &mut self,
1098 project: Entity<Project>,
1099 buffer: Entity<language::Buffer>,
1100 cursor_position: language::Anchor,
1101 cx: &mut Context<Self>,
1102 ) -> Task<Result<()>> {
1103 cx.spawn(async move |this, cx| {
1104 let related_excerpts_result = this
1105 .update(cx, |this, cx| {
1106 let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1107 return Task::ready(anyhow::Ok(HashMap::default()));
1108 };
1109
1110 let ContextMode::Llm(options) = &this.options().context else {
1111 return Task::ready(anyhow::Ok(HashMap::default()));
1112 };
1113
1114 let mut edit_history_unified_diff = String::new();
1115
1116 for event in zeta_project.events.iter() {
1117 if let Some(event) = event.to_request_event(cx) {
1118 writeln!(&mut edit_history_unified_diff, "{event}").ok();
1119 }
1120 }
1121
1122 find_related_excerpts(
1123 buffer.clone(),
1124 cursor_position,
1125 &project,
1126 edit_history_unified_diff,
1127 options,
1128 this.debug_tx.clone(),
1129 cx,
1130 )
1131 })?
1132 .await;
1133
1134 this.update(cx, |this, _cx| {
1135 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1136 return Ok(());
1137 };
1138 zeta_project.refresh_context_task.take();
1139 if let Some(debug_tx) = &this.debug_tx {
1140 debug_tx
1141 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1142 ZetaContextRetrievalDebugInfo {
1143 project,
1144 timestamp: Instant::now(),
1145 },
1146 ))
1147 .ok();
1148 }
1149 match related_excerpts_result {
1150 Ok(excerpts) => {
1151 zeta_project.context = Some(excerpts);
1152 Ok(())
1153 }
1154 Err(error) => Err(error),
1155 }
1156 })?
1157 })
1158 }
1159
1160 fn gather_nearby_diagnostics(
1161 cursor_offset: usize,
1162 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1163 snapshot: &BufferSnapshot,
1164 max_diagnostics_bytes: usize,
1165 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1166 // TODO: Could make this more efficient
1167 let mut diagnostic_groups = Vec::new();
1168 for (language_server_id, diagnostics) in diagnostic_sets {
1169 let mut groups = Vec::new();
1170 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1171 diagnostic_groups.extend(
1172 groups
1173 .into_iter()
1174 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1175 );
1176 }
1177
1178 // sort by proximity to cursor
1179 diagnostic_groups.sort_by_key(|group| {
1180 let range = &group.entries[group.primary_ix].range;
1181 if range.start >= cursor_offset {
1182 range.start - cursor_offset
1183 } else if cursor_offset >= range.end {
1184 cursor_offset - range.end
1185 } else {
1186 (cursor_offset - range.start).min(range.end - cursor_offset)
1187 }
1188 });
1189
1190 let mut results = Vec::new();
1191 let mut diagnostic_groups_truncated = false;
1192 let mut diagnostics_byte_count = 0;
1193 for group in diagnostic_groups {
1194 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1195 diagnostics_byte_count += raw_value.get().len();
1196 if diagnostics_byte_count > max_diagnostics_bytes {
1197 diagnostic_groups_truncated = true;
1198 break;
1199 }
1200 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1201 }
1202
1203 (results, diagnostic_groups_truncated)
1204 }
1205
1206 // TODO: Dedupe with similar code in request_prediction?
1207 pub fn cloud_request_for_zeta_cli(
1208 &mut self,
1209 project: &Entity<Project>,
1210 buffer: &Entity<Buffer>,
1211 position: language::Anchor,
1212 cx: &mut Context<Self>,
1213 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1214 let project_state = self.projects.get(&project.entity_id());
1215
1216 let index_state = project_state.map(|state| {
1217 state
1218 .syntax_index
1219 .read_with(cx, |index, _cx| index.state().clone())
1220 });
1221 let options = self.options.clone();
1222 let snapshot = buffer.read(cx).snapshot();
1223 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1224 return Task::ready(Err(anyhow!("No file path for excerpt")));
1225 };
1226 let worktree_snapshots = project
1227 .read(cx)
1228 .worktrees(cx)
1229 .map(|worktree| worktree.read(cx).snapshot())
1230 .collect::<Vec<_>>();
1231
1232 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1233 let mut path = f.worktree.read(cx).absolutize(&f.path);
1234 if path.pop() { Some(path) } else { None }
1235 });
1236
1237 cx.background_spawn(async move {
1238 let index_state = if let Some(index_state) = index_state {
1239 Some(index_state.lock_owned().await)
1240 } else {
1241 None
1242 };
1243
1244 let cursor_point = position.to_point(&snapshot);
1245
1246 let debug_info = true;
1247 EditPredictionContext::gather_context(
1248 cursor_point,
1249 &snapshot,
1250 parent_abs_path.as_deref(),
1251 match &options.context {
1252 ContextMode::Llm(_) => {
1253 // TODO
1254 panic!("Llm mode not supported in zeta cli yet");
1255 }
1256 ContextMode::Syntax(edit_prediction_context_options) => {
1257 edit_prediction_context_options
1258 }
1259 },
1260 index_state.as_deref(),
1261 )
1262 .context("Failed to select excerpt")
1263 .map(|context| {
1264 make_syntax_context_cloud_request(
1265 excerpt_path.into(),
1266 context,
1267 // TODO pass everything
1268 Vec::new(),
1269 false,
1270 Vec::new(),
1271 false,
1272 None,
1273 debug_info,
1274 &worktree_snapshots,
1275 index_state.as_deref(),
1276 Some(options.max_prompt_bytes),
1277 options.prompt_format,
1278 )
1279 })
1280 })
1281 }
1282
1283 pub fn wait_for_initial_indexing(
1284 &mut self,
1285 project: &Entity<Project>,
1286 cx: &mut App,
1287 ) -> Task<Result<()>> {
1288 let zeta_project = self.get_or_init_zeta_project(project, cx);
1289 zeta_project
1290 .syntax_index
1291 .read(cx)
1292 .wait_for_initial_file_indexing(cx)
1293 }
1294}
1295
1296#[derive(Error, Debug)]
1297#[error(
1298 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1299)]
1300pub struct ZedUpdateRequiredError {
1301 minimum_version: SemanticVersion,
1302}
1303
1304fn make_syntax_context_cloud_request(
1305 excerpt_path: Arc<Path>,
1306 context: EditPredictionContext,
1307 events: Vec<predict_edits_v3::Event>,
1308 can_collect_data: bool,
1309 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1310 diagnostic_groups_truncated: bool,
1311 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1312 debug_info: bool,
1313 worktrees: &Vec<worktree::Snapshot>,
1314 index_state: Option<&SyntaxIndexState>,
1315 prompt_max_bytes: Option<usize>,
1316 prompt_format: PromptFormat,
1317) -> predict_edits_v3::PredictEditsRequest {
1318 let mut signatures = Vec::new();
1319 let mut declaration_to_signature_index = HashMap::default();
1320 let mut referenced_declarations = Vec::new();
1321
1322 for snippet in context.declarations {
1323 let project_entry_id = snippet.declaration.project_entry_id();
1324 let Some(path) = worktrees.iter().find_map(|worktree| {
1325 worktree.entry_for_id(project_entry_id).map(|entry| {
1326 let mut full_path = RelPathBuf::new();
1327 full_path.push(worktree.root_name());
1328 full_path.push(&entry.path);
1329 full_path
1330 })
1331 }) else {
1332 continue;
1333 };
1334
1335 let parent_index = index_state.and_then(|index_state| {
1336 snippet.declaration.parent().and_then(|parent| {
1337 add_signature(
1338 parent,
1339 &mut declaration_to_signature_index,
1340 &mut signatures,
1341 index_state,
1342 )
1343 })
1344 });
1345
1346 let (text, text_is_truncated) = snippet.declaration.item_text();
1347 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1348 path: path.as_std_path().into(),
1349 text: text.into(),
1350 range: snippet.declaration.item_line_range(),
1351 text_is_truncated,
1352 signature_range: snippet.declaration.signature_range_in_item_text(),
1353 parent_index,
1354 signature_score: snippet.score(DeclarationStyle::Signature),
1355 declaration_score: snippet.score(DeclarationStyle::Declaration),
1356 score_components: snippet.components,
1357 });
1358 }
1359
1360 let excerpt_parent = index_state.and_then(|index_state| {
1361 context
1362 .excerpt
1363 .parent_declarations
1364 .last()
1365 .and_then(|(parent, _)| {
1366 add_signature(
1367 *parent,
1368 &mut declaration_to_signature_index,
1369 &mut signatures,
1370 index_state,
1371 )
1372 })
1373 });
1374
1375 predict_edits_v3::PredictEditsRequest {
1376 excerpt_path,
1377 excerpt: context.excerpt_text.body,
1378 excerpt_line_range: context.excerpt.line_range,
1379 excerpt_range: context.excerpt.range,
1380 cursor_point: predict_edits_v3::Point {
1381 line: predict_edits_v3::Line(context.cursor_point.row),
1382 column: context.cursor_point.column,
1383 },
1384 referenced_declarations,
1385 included_files: vec![],
1386 signatures,
1387 excerpt_parent,
1388 events,
1389 can_collect_data,
1390 diagnostic_groups,
1391 diagnostic_groups_truncated,
1392 git_info,
1393 debug_info,
1394 prompt_max_bytes,
1395 prompt_format,
1396 }
1397}
1398
1399fn add_signature(
1400 declaration_id: DeclarationId,
1401 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1402 signatures: &mut Vec<Signature>,
1403 index: &SyntaxIndexState,
1404) -> Option<usize> {
1405 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1406 return Some(*signature_index);
1407 }
1408 let Some(parent_declaration) = index.declaration(declaration_id) else {
1409 log::error!("bug: missing parent declaration");
1410 return None;
1411 };
1412 let parent_index = parent_declaration.parent().and_then(|parent| {
1413 add_signature(parent, declaration_to_signature_index, signatures, index)
1414 });
1415 let (text, text_is_truncated) = parent_declaration.signature_text();
1416 let signature_index = signatures.len();
1417 signatures.push(Signature {
1418 text: text.into(),
1419 text_is_truncated,
1420 parent_index,
1421 range: parent_declaration.signature_line_range(),
1422 });
1423 declaration_to_signature_index.insert(declaration_id, signature_index);
1424 Some(signature_index)
1425}
1426
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // `current_prediction_for_buffer` should report:
    // - a prediction that edits the buffer's own file as `Local`;
    // - a prediction that edits another file as `Jump`;
    // - that same prediction as `Local` when queried for the other file's buffer
    //   once it is open.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // The fake HTTP client forwards the request here; reply with a rewrite of
        // the whole of 1.txt.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        // The request is still made from buffer1, but the response edits 2.txt.
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Once 2.txt is open, the same prediction is local to its buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // A single request: checks the excerpt path and cursor point sent to the
    // server. Also checks that a full-range rewrite in the response is reduced to
    // inserting " are you?" at (1, 3).
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // An edit to a registered buffer is sent in the next request as a
    // `BufferChange` event carrying a unified diff of the change.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Diagnostics pushed for the file are included in the request as serialized
    // diagnostic groups, with ranges resolved to buffer offsets.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // Only the outgoing request is inspected; the response sender is dropped
        // unanswered.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // Position (1, 1) is offset 8. Position (1, 5) lies past the end of "Bye"
        // and comes out as offset 10, the end of the buffer.
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Sets up test settings, a fake HTTP client, and the global `Zeta`.
    //
    // The fake client answers "/client/llm_tokens" with a fixed token. Each
    // "/predict_edits/v3" request body is deserialized and sent, together with a
    // oneshot sender, to the returned receiver; the test answers through that
    // sender. Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // If the test drops the sender without replying,
                                // `res_rx.await` fails and so does this request.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}