1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::ops::Range;
32use std::path::Path;
33use std::str::FromStr as _;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use thiserror::Error;
37use util::ResultExt as _;
38use util::rel_path::RelPathBuf;
39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
40
41mod merge_excerpts;
42mod prediction;
43mod provider;
44mod related_excerpts;
45
46use crate::merge_excerpts::merge_excerpts;
47use crate::prediction::EditPrediction;
48use crate::related_excerpts::find_related_excerpts;
49pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
50pub use provider::ZetaEditPredictionProvider;
51
/// Maximum time between two consecutive edits for them to be coalesced into a
/// single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When a project's history reaches this size, `Zeta::push_event` drops the
/// oldest half before appending the next event.
const MAX_EVENT_COUNT: usize = 16;
56
/// Default excerpt options, shared by `DEFAULT_LLM_CONTEXT_OPTIONS` and
/// `DEFAULT_SYNTAX_CONTEXT_OPTIONS`.
// NOTE(review): the exact meaning of each field is defined by
// `EditPredictionExcerptOptions` in `edit_prediction_context` — confirm there.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
62
/// Default context mode: LLM-driven context retrieval with the default LLM options.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);

/// Default options for `ContextMode::Llm`; uses the default excerpt options.
pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
68
/// Default options for `ContextMode::Syntax`.
///
/// Note that `max_retrieved_declarations` is `0` in these defaults.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };
78
/// Options a freshly constructed `Zeta` starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
86
/// Feature flag identified by the name `"zeta2"`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Always returns `false`.
    // NOTE(review): presumably this means staff do not get the flag
    // automatically — confirm against the `FeatureFlag` trait documentation.
    fn enabled_for_staff() -> bool {
        false
    }
}
96
/// Wrapper that stores the shared `Zeta` entity as an app global
/// (set in `Zeta::global`, read in `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
101
/// Edit prediction state: API credentials, options, and per-project tracking.
pub struct Zeta {
    /// Client used to build LLM URLs, send HTTP requests, and acquire/refresh the token.
    client: Arc<Client>,
    /// Read for the current edit prediction usage and updated from API responses.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential on API requests.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription created in `new` alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options used when building prediction requests and retrieving context.
    options: ZetaOptions,
    /// Set to `true` once an API response fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// Sender for debug events; `None` until `debug_info` is called.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
112
/// Tunable options for building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for the request is gathered (LLM-driven or syntax-based).
    pub context: ContextMode,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format forwarded in the request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
121
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context comes from related excerpts found via `find_related_excerpts`.
    Llm(LlmContextOptions),
    /// Context comes from `EditPredictionContext::gather_context`.
    Syntax(EditPredictionContextOptions),
}

impl ContextMode {
    /// Returns the excerpt options of whichever mode is active.
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Llm(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}
136
/// Events published on the debug channel returned by `Zeta::debug_info`.
pub enum ZetaDebugInfo {
    /// Sent when a context refresh task starts running.
    ContextRetrievalStarted(ZetaContextRetrievalDebugInfo),
    // NOTE(review): the two search-query variants are not sent from this file's
    // visible code; presumably `find_related_excerpts` emits them — confirm there.
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    /// Sent after a context refresh task has stored its result.
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// Sent once a prediction request has been built, before it is sent.
    EditPredicted(ZetaEditPredictionDebugInfo),
}
144
/// Payload for context-retrieval debug events.
pub struct ZetaContextRetrievalDebugInfo {
    /// Project the retrieval belongs to.
    pub project: Entity<Project>,
    /// When the event was recorded.
    pub timestamp: Instant,
}
149
/// Payload describing a single prediction request, for debugging.
pub struct ZetaEditPredictionDebugInfo {
    /// Copy of the request that was built.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics and context and assembling the request.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt built locally from the request, or the build error as a string.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response, or with an error message if the
    /// request failed or was skipped.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}
158
/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
pub struct ZetaSearchQueryDebugInfo {
    /// Project the queries were generated for.
    pub project: Entity<Project>,
    /// When the event was recorded.
    pub timestamp: Instant,
    /// The search queries themselves.
    pub queries: Vec<SearchToolQuery>,
}
164
/// Alias for the debug info type defined by the `predict_edits_v3` API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
166
/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    /// Syntax index created for the project when it is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first (bounded by `MAX_EVENT_COUNT`).
    events: VecDeque<Event>,
    /// Buffers being observed for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Most recently accepted-for-display prediction, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts from the last completed context refresh, per buffer.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context refresh; cleared by the task itself when it finishes.
    refresh_context_task: Option<Task<Option<()>>>,
    /// Pending (possibly delayed) trigger for the next context refresh.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When a context refresh was last requested via `refresh_context_if_needed`.
    refresh_context_timestamp: Option<Instant>,
}
177
/// A prediction together with the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer that issued the request. This may differ from
    /// the buffer the prediction's edits apply to.
    pub requested_by_buffer_id: EntityId,
    /// The prediction itself.
    pub prediction: EditPrediction,
}
183
184impl CurrentEditPrediction {
185 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
186 let Some(new_edits) = self
187 .prediction
188 .interpolate(&self.prediction.buffer.read(cx))
189 else {
190 return false;
191 };
192
193 if self.prediction.buffer != old_prediction.prediction.buffer {
194 return true;
195 }
196
197 let Some(old_edits) = old_prediction
198 .prediction
199 .interpolate(&old_prediction.prediction.buffer.read(cx))
200 else {
201 return true;
202 };
203
204 // This reduces the occurrence of UI thrash from replacing edits
205 //
206 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
207 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
208 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
209 && old_edits.len() == 1
210 && new_edits.len() == 1
211 {
212 let (old_range, old_text) = &old_edits[0];
213 let (new_range, new_text) = &new_edits[0];
214 new_range == old_range && new_text.starts_with(old_text)
215 } else {
216 true
217 }
218 }
219}
220
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets the buffer being asked about.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets another one.
    Jump { prediction: &'a EditPrediction },
}
227
/// Tracking state for a buffer registered with a `ZetaProject`.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the "old" side of the next event.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer, kept alive for as long as the
    /// buffer stays registered.
    _subscriptions: [gpui::Subscription; 2],
}
232
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed between two snapshots.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change (advanced when later edits are coalesced).
        new_snapshot: BufferSnapshot,
        /// When the (latest coalesced) change was recorded.
        timestamp: Instant,
    },
}
241
242impl Event {
243 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
244 match self {
245 Event::BufferChange {
246 old_snapshot,
247 new_snapshot,
248 ..
249 } => {
250 let path = new_snapshot.file().map(|f| f.full_path(cx));
251
252 let old_path = old_snapshot.file().and_then(|f| {
253 let old_path = f.full_path(cx);
254 if Some(&old_path) != path.as_ref() {
255 Some(old_path)
256 } else {
257 None
258 }
259 });
260
261 // TODO [zeta2] move to bg?
262 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
263
264 if path == old_path && diff.is_empty() {
265 None
266 } else {
267 Some(predict_edits_v3::Event::BufferChange {
268 old_path,
269 path,
270 diff,
271 //todo: Actually detect if this edit was predicted or not
272 predicted: false,
273 })
274 }
275 }
276 }
277 }
278}
279
280impl Zeta {
281 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
282 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
283 }
284
285 pub fn global(
286 client: &Arc<Client>,
287 user_store: &Entity<UserStore>,
288 cx: &mut App,
289 ) -> Entity<Self> {
290 cx.try_global::<ZetaGlobal>()
291 .map(|global| global.0.clone())
292 .unwrap_or_else(|| {
293 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
294 cx.set_global(ZetaGlobal(zeta.clone()));
295 zeta
296 })
297 }
298
299 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
300 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
301
302 Self {
303 projects: HashMap::default(),
304 client,
305 user_store,
306 options: DEFAULT_OPTIONS,
307 llm_token: LlmApiToken::default(),
308 _llm_token_subscription: cx.subscribe(
309 &refresh_llm_token_listener,
310 |this, _listener, _event, cx| {
311 let client = this.client.clone();
312 let llm_token = this.llm_token.clone();
313 cx.spawn(async move |_this, _cx| {
314 llm_token.refresh(&client).await?;
315 anyhow::Ok(())
316 })
317 .detach_and_log_err(cx);
318 },
319 ),
320 update_required: false,
321 debug_tx: None,
322 }
323 }
324
325 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
326 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
327 self.debug_tx = Some(debug_watch_tx);
328 debug_watch_rx
329 }
330
    /// Returns the options currently in effect.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
334
    /// Replaces the options. Projects that are already registered keep the
    /// syntax index they were created with.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
338
339 pub fn clear_history(&mut self) {
340 for zeta_project in self.projects.values_mut() {
341 zeta_project.events.clear();
342 }
343 }
344
345 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
346 self.projects
347 .get(&project.entity_id())
348 .map(|project| project.events.iter())
349 .into_iter()
350 .flatten()
351 }
352
353 pub fn context_for_project(
354 &self,
355 project: &Entity<Project>,
356 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
357 self.projects
358 .get(&project.entity_id())
359 .and_then(|project| {
360 Some(
361 project
362 .context
363 .as_ref()?
364 .iter()
365 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
366 )
367 })
368 .into_iter()
369 .flatten()
370 }
371
    /// Returns the edit prediction usage currently held by the user store, if any.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
375
    /// Ensures per-project state exists for `project`, creating it (including
    /// its syntax index) if this is the first time the project is seen.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
379
    /// Starts tracking edits to `buffer` as part of `project`.
    ///
    /// Registers the project first if needed. Does nothing further if the
    /// buffer is already registered.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
389
390 fn get_or_init_zeta_project(
391 &mut self,
392 project: &Entity<Project>,
393 cx: &mut App,
394 ) -> &mut ZetaProject {
395 self.projects
396 .entry(project.entity_id())
397 .or_insert_with(|| ZetaProject {
398 syntax_index: cx.new(|cx| {
399 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
400 }),
401 events: VecDeque::new(),
402 registered_buffers: HashMap::default(),
403 current_prediction: None,
404 context: None,
405 refresh_context_task: None,
406 refresh_context_debounce_task: None,
407 refresh_context_timestamp: None,
408 })
409 }
410
411 fn register_buffer_impl<'a>(
412 zeta_project: &'a mut ZetaProject,
413 buffer: &Entity<Buffer>,
414 project: &Entity<Project>,
415 cx: &mut Context<Self>,
416 ) -> &'a mut RegisteredBuffer {
417 let buffer_id = buffer.entity_id();
418 match zeta_project.registered_buffers.entry(buffer_id) {
419 hash_map::Entry::Occupied(entry) => entry.into_mut(),
420 hash_map::Entry::Vacant(entry) => {
421 let snapshot = buffer.read(cx).snapshot();
422 let project_entity_id = project.entity_id();
423 entry.insert(RegisteredBuffer {
424 snapshot,
425 _subscriptions: [
426 cx.subscribe(buffer, {
427 let project = project.downgrade();
428 move |this, buffer, event, cx| {
429 if let language::BufferEvent::Edited = event
430 && let Some(project) = project.upgrade()
431 {
432 this.report_changes_for_buffer(&buffer, &project, cx);
433 }
434 }
435 }),
436 cx.observe_release(buffer, move |this, _buffer, _cx| {
437 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
438 else {
439 return;
440 };
441 zeta_project.registered_buffers.remove(&buffer_id);
442 }),
443 ],
444 })
445 }
446 }
447 }
448
449 fn report_changes_for_buffer(
450 &mut self,
451 buffer: &Entity<Buffer>,
452 project: &Entity<Project>,
453 cx: &mut Context<Self>,
454 ) -> BufferSnapshot {
455 let zeta_project = self.get_or_init_zeta_project(project, cx);
456 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
457
458 let new_snapshot = buffer.read(cx).snapshot();
459 if new_snapshot.version != registered_buffer.snapshot.version {
460 let old_snapshot =
461 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
462 Self::push_event(
463 zeta_project,
464 Event::BufferChange {
465 old_snapshot,
466 new_snapshot: new_snapshot.clone(),
467 timestamp: Instant::now(),
468 },
469 );
470 }
471
472 new_snapshot
473 }
474
475 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
476 let events = &mut zeta_project.events;
477
478 if let Some(Event::BufferChange {
479 new_snapshot: last_new_snapshot,
480 timestamp: last_timestamp,
481 ..
482 }) = events.back_mut()
483 {
484 // Coalesce edits for the same buffer when they happen one after the other.
485 let Event::BufferChange {
486 old_snapshot,
487 new_snapshot,
488 timestamp,
489 } = &event;
490
491 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
492 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
493 && old_snapshot.version == last_new_snapshot.version
494 {
495 *last_new_snapshot = new_snapshot.clone();
496 *last_timestamp = *timestamp;
497 return;
498 }
499 }
500
501 if events.len() >= MAX_EVENT_COUNT {
502 // These are halved instead of popping to improve prompt caching.
503 events.drain(..MAX_EVENT_COUNT / 2);
504 }
505
506 events.push_back(event);
507 }
508
509 fn current_prediction_for_buffer(
510 &self,
511 buffer: &Entity<Buffer>,
512 project: &Entity<Project>,
513 cx: &App,
514 ) -> Option<BufferEditPrediction<'_>> {
515 let project_state = self.projects.get(&project.entity_id())?;
516
517 let CurrentEditPrediction {
518 requested_by_buffer_id,
519 prediction,
520 } = project_state.current_prediction.as_ref()?;
521
522 if prediction.targets_buffer(buffer.read(cx), cx) {
523 Some(BufferEditPrediction::Local { prediction })
524 } else if *requested_by_buffer_id == buffer.entity_id() {
525 Some(BufferEditPrediction::Jump { prediction })
526 } else {
527 None
528 }
529 }
530
531 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
532 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
533 return;
534 };
535
536 let Some(prediction) = project_state.current_prediction.take() else {
537 return;
538 };
539 let request_id = prediction.prediction.id.into();
540
541 let client = self.client.clone();
542 let llm_token = self.llm_token.clone();
543 let app_version = AppVersion::global(cx);
544 cx.spawn(async move |this, cx| {
545 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
546 http_client::Url::parse(&predict_edits_url)?
547 } else {
548 client
549 .http_client()
550 .build_zed_llm_url("/predict_edits/accept", &[])?
551 };
552
553 let response = cx
554 .background_spawn(Self::send_api_request::<()>(
555 move |builder| {
556 let req = builder.uri(url.as_ref()).body(
557 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
558 );
559 Ok(req?)
560 },
561 client,
562 llm_token,
563 app_version,
564 ))
565 .await;
566
567 Self::handle_api_response(&this, response, cx)?;
568 anyhow::Ok(())
569 })
570 .detach_and_log_err(cx);
571 }
572
573 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
574 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
575 project_state.current_prediction.take();
576 };
577 }
578
579 pub fn refresh_prediction(
580 &mut self,
581 project: &Entity<Project>,
582 buffer: &Entity<Buffer>,
583 position: language::Anchor,
584 cx: &mut Context<Self>,
585 ) -> Task<Result<()>> {
586 let request_task = self.request_prediction(project, buffer, position, cx);
587 let buffer = buffer.clone();
588 let project = project.clone();
589
590 cx.spawn(async move |this, cx| {
591 if let Some(prediction) = request_task.await? {
592 this.update(cx, |this, cx| {
593 let project_state = this
594 .projects
595 .get_mut(&project.entity_id())
596 .context("Project not found")?;
597
598 let new_prediction = CurrentEditPrediction {
599 requested_by_buffer_id: buffer.entity_id(),
600 prediction: prediction,
601 };
602
603 if project_state
604 .current_prediction
605 .as_ref()
606 .is_none_or(|old_prediction| {
607 new_prediction.should_replace_prediction(&old_prediction, cx)
608 })
609 {
610 project_state.current_prediction = Some(new_prediction);
611 }
612 anyhow::Ok(())
613 })??;
614 }
615 Ok(())
616 })
617 }
618
    /// Builds a prediction request for `position` in `buffer`, sends it, and
    /// converts the response into an `EditPrediction`.
    ///
    /// Work happens in three phases:
    /// 1. Synchronously, everything that needs `cx` is captured: the buffer
    ///    snapshot, recorded events, diagnostics, worktree snapshots, and the
    ///    project's previously retrieved context.
    /// 2. On a background task, the request is assembled according to the
    ///    configured `ContextMode` and sent.
    /// 3. Back on a foreground task, the API response is handled and turned
    ///    into an `EditPrediction`.
    ///
    /// The task resolves to `Ok(None)` when no excerpt/context could be
    /// selected around the cursor.
    ///
    /// # Errors
    ///
    /// Fails immediately if the buffer has no file. The task fails if the
    /// request fails, or (debug builds only) if `ZED_ZETA2_SKIP_REQUEST` is set.
    fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // The project may not be registered; everything derived from its state
        // is optional below.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert the recorded history into request events, dropping no-ops.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        // Absolute path of the directory containing the buffer's file, if any.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Previously retrieved context: (snapshot, full path, ranges) per
        // buffer. Buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer, ranges)| {
                let buffer = buffer.read(cx);
                Some((
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                // Start of the interval reported as `retrieval_time` in debug info.
                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = match options.context {
                    ContextMode::Llm(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
                            ..snapshot.anchor_before(excerpt.range.end);

                        if let Some(buffer_ix) = included_files
                            .iter()
                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
                        {
                            // The cursor's buffer already has retrieved ranges:
                            // insert the cursor excerpt at the position given by
                            // the ordering below (start ascending, end descending).
                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            // Move the cursor's file to the end of the list.
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Resolve anchor ranges to row ranges and merge them per file.
                        let included_files = included_files
                            .into_iter()
                            .map(|(buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path,
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        // In this mode the excerpt travels inside `included_files`,
                        // so the dedicated excerpt fields are left empty.
                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When a debug listener is attached, publish the request now and
                // keep a sender so the response can be forwarded later.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    let local_prompt = build_prompt(&request)
                        .map(|(prompt, _)| prompt)
                        .map_err(|err| err.to_string());

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
                            request: request.clone(),
                            retrieval_time,
                            buffer: buffer.downgrade(),
                            local_prompt,
                            position,
                            response_rx,
                        }))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug builds only: allow skipping the network request entirely.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send(Err("Request skipped".to_string()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let response =
                    Self::send_prediction_request(client, llm_token, app_version, request).await;

                // Forward the outcome (response body or error text) to the debug listener.
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                        )
                        .ok();
                }

                response.map(|(res, usage)| (Some(res), usage))
            }
        });

        let buffer = buffer.clone();

        cx.spawn({
            let project = project.clone();
            async move |this, cx| {
                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
            }
        })
    }
883
884 async fn send_prediction_request(
885 client: Arc<Client>,
886 llm_token: LlmApiToken,
887 app_version: SemanticVersion,
888 request: predict_edits_v3::PredictEditsRequest,
889 ) -> Result<(
890 predict_edits_v3::PredictEditsResponse,
891 Option<EditPredictionUsage>,
892 )> {
893 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
894 http_client::Url::parse(&predict_edits_url)?
895 } else {
896 client
897 .http_client()
898 .build_zed_llm_url("/predict_edits/v3", &[])?
899 };
900
901 Self::send_api_request(
902 |builder| {
903 let req = builder
904 .uri(url.as_ref())
905 .body(serde_json::to_string(&request)?.into());
906 Ok(req?)
907 },
908 client,
909 llm_token,
910 app_version,
911 )
912 .await
913 }
914
915 fn handle_api_response<T>(
916 this: &WeakEntity<Self>,
917 response: Result<(T, Option<EditPredictionUsage>)>,
918 cx: &mut gpui::AsyncApp,
919 ) -> Result<T> {
920 match response {
921 Ok((data, usage)) => {
922 if let Some(usage) = usage {
923 this.update(cx, |this, cx| {
924 this.user_store.update(cx, |user_store, cx| {
925 user_store.update_edit_prediction_usage(usage, cx);
926 });
927 })
928 .ok();
929 }
930 Ok(data)
931 }
932 Err(err) => {
933 if err.is::<ZedUpdateRequiredError>() {
934 cx.update(|cx| {
935 this.update(cx, |this, _cx| {
936 this.update_required = true;
937 })
938 .ok();
939
940 let error_message: SharedString = err.to_string().into();
941 show_app_notification(
942 NotificationId::unique::<ZedUpdateRequiredError>(),
943 cx,
944 move |cx| {
945 cx.new(|cx| {
946 ErrorMessagePrompt::new(error_message.clone(), cx)
947 .with_link_button("Update Zed", "https://zed.dev/releases")
948 })
949 },
950 );
951 })
952 .ok();
953 }
954 Err(err)
955 }
956 }
957 }
958
959 async fn send_api_request<Res>(
960 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
961 client: Arc<Client>,
962 llm_token: LlmApiToken,
963 app_version: SemanticVersion,
964 ) -> Result<(Res, Option<EditPredictionUsage>)>
965 where
966 Res: DeserializeOwned,
967 {
968 let http_client = client.http_client();
969 let mut token = llm_token.acquire(&client).await?;
970 let mut did_retry = false;
971
972 loop {
973 let request_builder = http_client::Request::builder().method(Method::POST);
974
975 let request = build(
976 request_builder
977 .header("Content-Type", "application/json")
978 .header("Authorization", format!("Bearer {}", token))
979 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
980 )?;
981
982 let mut response = http_client.send(request).await?;
983
984 if let Some(minimum_required_version) = response
985 .headers()
986 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
987 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
988 {
989 anyhow::ensure!(
990 app_version >= minimum_required_version,
991 ZedUpdateRequiredError {
992 minimum_version: minimum_required_version
993 }
994 );
995 }
996
997 if response.status().is_success() {
998 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
999
1000 let mut body = Vec::new();
1001 response.body_mut().read_to_end(&mut body).await?;
1002 return Ok((serde_json::from_slice(&body)?, usage));
1003 } else if !did_retry
1004 && response
1005 .headers()
1006 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1007 .is_some()
1008 {
1009 did_retry = true;
1010 token = llm_token.refresh(&client).await?;
1011 } else {
1012 let mut body = String::new();
1013 response.body_mut().read_to_string(&mut body).await?;
1014 anyhow::bail!(
1015 "Request failed with status: {:?}\nBody: {}",
1016 response.status(),
1017 body
1018 );
1019 }
1020 }
1021 }
1022
    /// If more than this much time has passed since the previous call to
    /// `refresh_context_if_needed`, the user is considered to have been idle
    /// and context is refreshed without delay.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Delay applied before refreshing context when the user was not idle.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1025
1026 // Refresh the related excerpts when the user just beguns editing after
1027 // an idle period, and after they pause editing.
1028 fn refresh_context_if_needed(
1029 &mut self,
1030 project: &Entity<Project>,
1031 buffer: &Entity<language::Buffer>,
1032 cursor_position: language::Anchor,
1033 cx: &mut Context<Self>,
1034 ) {
1035 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1036 return;
1037 }
1038
1039 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1040 return;
1041 };
1042
1043 let now = Instant::now();
1044 let was_idle = zeta_project
1045 .refresh_context_timestamp
1046 .map_or(true, |timestamp| {
1047 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1048 });
1049 zeta_project.refresh_context_timestamp = Some(now);
1050 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1051 let buffer = buffer.clone();
1052 let project = project.clone();
1053 async move |this, cx| {
1054 if was_idle {
1055 log::debug!("refetching edit prediction context after idle");
1056 } else {
1057 cx.background_executor()
1058 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1059 .await;
1060 log::debug!("refetching edit prediction context after pause");
1061 }
1062 this.update(cx, |this, cx| {
1063 this.refresh_context(project, buffer, cursor_position, cx);
1064 })
1065 .ok()
1066 }
1067 }));
1068 }
1069
1070 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1071 // and avoid spawning more than one concurrent task.
1072 fn refresh_context(
1073 &mut self,
1074 project: Entity<Project>,
1075 buffer: Entity<language::Buffer>,
1076 cursor_position: language::Anchor,
1077 cx: &mut Context<Self>,
1078 ) {
1079 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1080 return;
1081 };
1082
1083 let debug_tx = self.debug_tx.clone();
1084
1085 zeta_project
1086 .refresh_context_task
1087 .get_or_insert(cx.spawn(async move |this, cx| {
1088 if let Some(debug_tx) = &debug_tx {
1089 debug_tx
1090 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
1091 ZetaContextRetrievalDebugInfo {
1092 project: project.clone(),
1093 timestamp: Instant::now(),
1094 },
1095 ))
1096 .ok();
1097 }
1098
1099 let related_excerpts = this
1100 .update(cx, |this, cx| {
1101 let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1102 return Task::ready(anyhow::Ok(HashMap::default()));
1103 };
1104
1105 let ContextMode::Llm(options) = &this.options().context else {
1106 return Task::ready(anyhow::Ok(HashMap::default()));
1107 };
1108
1109 find_related_excerpts(
1110 buffer.clone(),
1111 cursor_position,
1112 &project,
1113 zeta_project.events.iter(),
1114 options,
1115 debug_tx,
1116 cx,
1117 )
1118 })
1119 .ok()?
1120 .await
1121 .log_err()
1122 .unwrap_or_default();
1123 this.update(cx, |this, _cx| {
1124 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1125 return;
1126 };
1127 zeta_project.context = Some(related_excerpts);
1128 zeta_project.refresh_context_task.take();
1129 if let Some(debug_tx) = &this.debug_tx {
1130 debug_tx
1131 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1132 ZetaContextRetrievalDebugInfo {
1133 project,
1134 timestamp: Instant::now(),
1135 },
1136 ))
1137 .ok();
1138 }
1139 })
1140 .ok()
1141 }));
1142 }
1143
1144 fn gather_nearby_diagnostics(
1145 cursor_offset: usize,
1146 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1147 snapshot: &BufferSnapshot,
1148 max_diagnostics_bytes: usize,
1149 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1150 // TODO: Could make this more efficient
1151 let mut diagnostic_groups = Vec::new();
1152 for (language_server_id, diagnostics) in diagnostic_sets {
1153 let mut groups = Vec::new();
1154 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1155 diagnostic_groups.extend(
1156 groups
1157 .into_iter()
1158 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1159 );
1160 }
1161
1162 // sort by proximity to cursor
1163 diagnostic_groups.sort_by_key(|group| {
1164 let range = &group.entries[group.primary_ix].range;
1165 if range.start >= cursor_offset {
1166 range.start - cursor_offset
1167 } else if cursor_offset >= range.end {
1168 cursor_offset - range.end
1169 } else {
1170 (cursor_offset - range.start).min(range.end - cursor_offset)
1171 }
1172 });
1173
1174 let mut results = Vec::new();
1175 let mut diagnostic_groups_truncated = false;
1176 let mut diagnostics_byte_count = 0;
1177 for group in diagnostic_groups {
1178 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1179 diagnostics_byte_count += raw_value.get().len();
1180 if diagnostics_byte_count > max_diagnostics_bytes {
1181 diagnostic_groups_truncated = true;
1182 break;
1183 }
1184 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1185 }
1186
1187 (results, diagnostic_groups_truncated)
1188 }
1189
1190 // TODO: Dedupe with similar code in request_prediction?
1191 pub fn cloud_request_for_zeta_cli(
1192 &mut self,
1193 project: &Entity<Project>,
1194 buffer: &Entity<Buffer>,
1195 position: language::Anchor,
1196 cx: &mut Context<Self>,
1197 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1198 let project_state = self.projects.get(&project.entity_id());
1199
1200 let index_state = project_state.map(|state| {
1201 state
1202 .syntax_index
1203 .read_with(cx, |index, _cx| index.state().clone())
1204 });
1205 let options = self.options.clone();
1206 let snapshot = buffer.read(cx).snapshot();
1207 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1208 return Task::ready(Err(anyhow!("No file path for excerpt")));
1209 };
1210 let worktree_snapshots = project
1211 .read(cx)
1212 .worktrees(cx)
1213 .map(|worktree| worktree.read(cx).snapshot())
1214 .collect::<Vec<_>>();
1215
1216 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1217 let mut path = f.worktree.read(cx).absolutize(&f.path);
1218 if path.pop() { Some(path) } else { None }
1219 });
1220
1221 cx.background_spawn(async move {
1222 let index_state = if let Some(index_state) = index_state {
1223 Some(index_state.lock_owned().await)
1224 } else {
1225 None
1226 };
1227
1228 let cursor_point = position.to_point(&snapshot);
1229
1230 let debug_info = true;
1231 EditPredictionContext::gather_context(
1232 cursor_point,
1233 &snapshot,
1234 parent_abs_path.as_deref(),
1235 match &options.context {
1236 ContextMode::Llm(_) => {
1237 // TODO
1238 panic!("Llm mode not supported in zeta cli yet");
1239 }
1240 ContextMode::Syntax(edit_prediction_context_options) => {
1241 edit_prediction_context_options
1242 }
1243 },
1244 index_state.as_deref(),
1245 )
1246 .context("Failed to select excerpt")
1247 .map(|context| {
1248 make_syntax_context_cloud_request(
1249 excerpt_path.into(),
1250 context,
1251 // TODO pass everything
1252 Vec::new(),
1253 false,
1254 Vec::new(),
1255 false,
1256 None,
1257 debug_info,
1258 &worktree_snapshots,
1259 index_state.as_deref(),
1260 Some(options.max_prompt_bytes),
1261 options.prompt_format,
1262 )
1263 })
1264 })
1265 }
1266
1267 pub fn wait_for_initial_indexing(
1268 &mut self,
1269 project: &Entity<Project>,
1270 cx: &mut App,
1271 ) -> Task<Result<()>> {
1272 let zeta_project = self.get_or_init_zeta_project(project, cx);
1273 zeta_project
1274 .syntax_index
1275 .read(cx)
1276 .wait_for_initial_file_indexing(cx)
1277 }
1278}
1279
1280#[derive(Error, Debug)]
1281#[error(
1282 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1283)]
1284pub struct ZedUpdateRequiredError {
1285 minimum_version: SemanticVersion,
1286}
1287
1288fn make_syntax_context_cloud_request(
1289 excerpt_path: Arc<Path>,
1290 context: EditPredictionContext,
1291 events: Vec<predict_edits_v3::Event>,
1292 can_collect_data: bool,
1293 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1294 diagnostic_groups_truncated: bool,
1295 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1296 debug_info: bool,
1297 worktrees: &Vec<worktree::Snapshot>,
1298 index_state: Option<&SyntaxIndexState>,
1299 prompt_max_bytes: Option<usize>,
1300 prompt_format: PromptFormat,
1301) -> predict_edits_v3::PredictEditsRequest {
1302 let mut signatures = Vec::new();
1303 let mut declaration_to_signature_index = HashMap::default();
1304 let mut referenced_declarations = Vec::new();
1305
1306 for snippet in context.declarations {
1307 let project_entry_id = snippet.declaration.project_entry_id();
1308 let Some(path) = worktrees.iter().find_map(|worktree| {
1309 worktree.entry_for_id(project_entry_id).map(|entry| {
1310 let mut full_path = RelPathBuf::new();
1311 full_path.push(worktree.root_name());
1312 full_path.push(&entry.path);
1313 full_path
1314 })
1315 }) else {
1316 continue;
1317 };
1318
1319 let parent_index = index_state.and_then(|index_state| {
1320 snippet.declaration.parent().and_then(|parent| {
1321 add_signature(
1322 parent,
1323 &mut declaration_to_signature_index,
1324 &mut signatures,
1325 index_state,
1326 )
1327 })
1328 });
1329
1330 let (text, text_is_truncated) = snippet.declaration.item_text();
1331 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1332 path: path.as_std_path().into(),
1333 text: text.into(),
1334 range: snippet.declaration.item_line_range(),
1335 text_is_truncated,
1336 signature_range: snippet.declaration.signature_range_in_item_text(),
1337 parent_index,
1338 signature_score: snippet.score(DeclarationStyle::Signature),
1339 declaration_score: snippet.score(DeclarationStyle::Declaration),
1340 score_components: snippet.components,
1341 });
1342 }
1343
1344 let excerpt_parent = index_state.and_then(|index_state| {
1345 context
1346 .excerpt
1347 .parent_declarations
1348 .last()
1349 .and_then(|(parent, _)| {
1350 add_signature(
1351 *parent,
1352 &mut declaration_to_signature_index,
1353 &mut signatures,
1354 index_state,
1355 )
1356 })
1357 });
1358
1359 predict_edits_v3::PredictEditsRequest {
1360 excerpt_path,
1361 excerpt: context.excerpt_text.body,
1362 excerpt_line_range: context.excerpt.line_range,
1363 excerpt_range: context.excerpt.range,
1364 cursor_point: predict_edits_v3::Point {
1365 line: predict_edits_v3::Line(context.cursor_point.row),
1366 column: context.cursor_point.column,
1367 },
1368 referenced_declarations,
1369 included_files: vec![],
1370 signatures,
1371 excerpt_parent,
1372 events,
1373 can_collect_data,
1374 diagnostic_groups,
1375 diagnostic_groups_truncated,
1376 git_info,
1377 debug_info,
1378 prompt_max_bytes,
1379 prompt_format,
1380 }
1381}
1382
1383fn add_signature(
1384 declaration_id: DeclarationId,
1385 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1386 signatures: &mut Vec<Signature>,
1387 index: &SyntaxIndexState,
1388) -> Option<usize> {
1389 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1390 return Some(*signature_index);
1391 }
1392 let Some(parent_declaration) = index.declaration(declaration_id) else {
1393 log::error!("bug: missing parent declaration");
1394 return None;
1395 };
1396 let parent_index = parent_declaration.parent().and_then(|parent| {
1397 add_signature(parent, declaration_to_signature_index, signatures, index)
1398 });
1399 let (text, text_is_truncated) = parent_declaration.signature_text();
1400 let signature_index = signatures.len();
1401 signatures.push(Signature {
1402 text: text.into(),
1403 text_is_truncated,
1404 parent_index,
1405 range: parent_declaration.signature_line_range(),
1406 });
1407 declaration_to_signature_index.insert(declaration_id, signature_index);
1408 Some(signature_index)
1409}
1410
1411#[cfg(test)]
1412mod tests {
1413 use std::{
1414 path::{Path, PathBuf},
1415 sync::Arc,
1416 };
1417
1418 use client::UserStore;
1419 use clock::FakeSystemClock;
1420 use cloud_llm_client::predict_edits_v3::{self, Point};
1421 use edit_prediction_context::Line;
1422 use futures::{
1423 AsyncReadExt, StreamExt,
1424 channel::{mpsc, oneshot},
1425 };
1426 use gpui::{
1427 Entity, TestAppContext,
1428 http_client::{FakeHttpClient, Response},
1429 prelude::*,
1430 };
1431 use indoc::indoc;
1432 use language::{LanguageServerId, OffsetRangeExt as _};
1433 use pretty_assertions::{assert_eq, assert_matches};
1434 use project::{FakeFs, Project};
1435 use serde_json::json;
1436 use settings::SettingsStore;
1437 use util::path;
1438 use uuid::Uuid;
1439
1440 use crate::{BufferEditPrediction, Zeta};
1441
1442 #[gpui::test]
1443 async fn test_current_state(cx: &mut TestAppContext) {
1444 let (zeta, mut req_rx) = init_test(cx);
1445 let fs = FakeFs::new(cx.executor());
1446 fs.insert_tree(
1447 "/root",
1448 json!({
1449 "1.txt": "Hello!\nHow\nBye",
1450 "2.txt": "Hola!\nComo\nAdios"
1451 }),
1452 )
1453 .await;
1454 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1455
1456 zeta.update(cx, |zeta, cx| {
1457 zeta.register_project(&project, cx);
1458 });
1459
1460 let buffer1 = project
1461 .update(cx, |project, cx| {
1462 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1463 project.open_buffer(path, cx)
1464 })
1465 .await
1466 .unwrap();
1467 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1468 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1469
1470 // Prediction for current file
1471
1472 let prediction_task = zeta.update(cx, |zeta, cx| {
1473 zeta.refresh_prediction(&project, &buffer1, position, cx)
1474 });
1475 let (_request, respond_tx) = req_rx.next().await.unwrap();
1476 respond_tx
1477 .send(predict_edits_v3::PredictEditsResponse {
1478 request_id: Uuid::new_v4(),
1479 edits: vec![predict_edits_v3::Edit {
1480 path: Path::new(path!("root/1.txt")).into(),
1481 range: Line(0)..Line(snapshot1.max_point().row + 1),
1482 content: "Hello!\nHow are you?\nBye".into(),
1483 }],
1484 debug_info: None,
1485 })
1486 .unwrap();
1487 prediction_task.await.unwrap();
1488
1489 zeta.read_with(cx, |zeta, cx| {
1490 let prediction = zeta
1491 .current_prediction_for_buffer(&buffer1, &project, cx)
1492 .unwrap();
1493 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1494 });
1495
1496 // Prediction for another file
1497 let prediction_task = zeta.update(cx, |zeta, cx| {
1498 zeta.refresh_prediction(&project, &buffer1, position, cx)
1499 });
1500 let (_request, respond_tx) = req_rx.next().await.unwrap();
1501 respond_tx
1502 .send(predict_edits_v3::PredictEditsResponse {
1503 request_id: Uuid::new_v4(),
1504 edits: vec![predict_edits_v3::Edit {
1505 path: Path::new(path!("root/2.txt")).into(),
1506 range: Line(0)..Line(snapshot1.max_point().row + 1),
1507 content: "Hola!\nComo estas?\nAdios".into(),
1508 }],
1509 debug_info: None,
1510 })
1511 .unwrap();
1512 prediction_task.await.unwrap();
1513 zeta.read_with(cx, |zeta, cx| {
1514 let prediction = zeta
1515 .current_prediction_for_buffer(&buffer1, &project, cx)
1516 .unwrap();
1517 assert_matches!(
1518 prediction,
1519 BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1520 );
1521 });
1522
1523 let buffer2 = project
1524 .update(cx, |project, cx| {
1525 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1526 project.open_buffer(path, cx)
1527 })
1528 .await
1529 .unwrap();
1530
1531 zeta.read_with(cx, |zeta, cx| {
1532 let prediction = zeta
1533 .current_prediction_for_buffer(&buffer2, &project, cx)
1534 .unwrap();
1535 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1536 });
1537 }
1538
1539 #[gpui::test]
1540 async fn test_simple_request(cx: &mut TestAppContext) {
1541 let (zeta, mut req_rx) = init_test(cx);
1542 let fs = FakeFs::new(cx.executor());
1543 fs.insert_tree(
1544 "/root",
1545 json!({
1546 "foo.md": "Hello!\nHow\nBye"
1547 }),
1548 )
1549 .await;
1550 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1551
1552 let buffer = project
1553 .update(cx, |project, cx| {
1554 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1555 project.open_buffer(path, cx)
1556 })
1557 .await
1558 .unwrap();
1559 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1560 let position = snapshot.anchor_before(language::Point::new(1, 3));
1561
1562 let prediction_task = zeta.update(cx, |zeta, cx| {
1563 zeta.request_prediction(&project, &buffer, position, cx)
1564 });
1565
1566 let (request, respond_tx) = req_rx.next().await.unwrap();
1567 assert_eq!(
1568 request.excerpt_path.as_ref(),
1569 Path::new(path!("root/foo.md"))
1570 );
1571 assert_eq!(
1572 request.cursor_point,
1573 Point {
1574 line: Line(1),
1575 column: 3
1576 }
1577 );
1578
1579 respond_tx
1580 .send(predict_edits_v3::PredictEditsResponse {
1581 request_id: Uuid::new_v4(),
1582 edits: vec![predict_edits_v3::Edit {
1583 path: Path::new(path!("root/foo.md")).into(),
1584 range: Line(0)..Line(snapshot.max_point().row + 1),
1585 content: "Hello!\nHow are you?\nBye".into(),
1586 }],
1587 debug_info: None,
1588 })
1589 .unwrap();
1590
1591 let prediction = prediction_task.await.unwrap().unwrap();
1592
1593 assert_eq!(prediction.edits.len(), 1);
1594 assert_eq!(
1595 prediction.edits[0].0.to_point(&snapshot).start,
1596 language::Point::new(1, 3)
1597 );
1598 assert_eq!(prediction.edits[0].1, " are you?");
1599 }
1600
1601 #[gpui::test]
1602 async fn test_request_events(cx: &mut TestAppContext) {
1603 let (zeta, mut req_rx) = init_test(cx);
1604 let fs = FakeFs::new(cx.executor());
1605 fs.insert_tree(
1606 "/root",
1607 json!({
1608 "foo.md": "Hello!\n\nBye"
1609 }),
1610 )
1611 .await;
1612 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1613
1614 let buffer = project
1615 .update(cx, |project, cx| {
1616 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1617 project.open_buffer(path, cx)
1618 })
1619 .await
1620 .unwrap();
1621
1622 zeta.update(cx, |zeta, cx| {
1623 zeta.register_buffer(&buffer, &project, cx);
1624 });
1625
1626 buffer.update(cx, |buffer, cx| {
1627 buffer.edit(vec![(7..7, "How")], None, cx);
1628 });
1629
1630 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1631 let position = snapshot.anchor_before(language::Point::new(1, 3));
1632
1633 let prediction_task = zeta.update(cx, |zeta, cx| {
1634 zeta.request_prediction(&project, &buffer, position, cx)
1635 });
1636
1637 let (request, respond_tx) = req_rx.next().await.unwrap();
1638
1639 assert_eq!(request.events.len(), 1);
1640 assert_eq!(
1641 request.events[0],
1642 predict_edits_v3::Event::BufferChange {
1643 path: Some(PathBuf::from(path!("root/foo.md"))),
1644 old_path: None,
1645 diff: indoc! {"
1646 @@ -1,3 +1,3 @@
1647 Hello!
1648 -
1649 +How
1650 Bye
1651 "}
1652 .to_string(),
1653 predicted: false
1654 }
1655 );
1656
1657 respond_tx
1658 .send(predict_edits_v3::PredictEditsResponse {
1659 request_id: Uuid::new_v4(),
1660 edits: vec![predict_edits_v3::Edit {
1661 path: Path::new(path!("root/foo.md")).into(),
1662 range: Line(0)..Line(snapshot.max_point().row + 1),
1663 content: "Hello!\nHow are you?\nBye".into(),
1664 }],
1665 debug_info: None,
1666 })
1667 .unwrap();
1668
1669 let prediction = prediction_task.await.unwrap().unwrap();
1670
1671 assert_eq!(prediction.edits.len(), 1);
1672 assert_eq!(
1673 prediction.edits[0].0.to_point(&snapshot).start,
1674 language::Point::new(1, 3)
1675 );
1676 assert_eq!(prediction.edits[0].1, " are you?");
1677 }
1678
1679 #[gpui::test]
1680 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1681 let (zeta, mut req_rx) = init_test(cx);
1682 let fs = FakeFs::new(cx.executor());
1683 fs.insert_tree(
1684 "/root",
1685 json!({
1686 "foo.md": "Hello!\nBye"
1687 }),
1688 )
1689 .await;
1690 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1691
1692 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1693 let diagnostic = lsp::Diagnostic {
1694 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1695 severity: Some(lsp::DiagnosticSeverity::ERROR),
1696 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1697 ..Default::default()
1698 };
1699
1700 project.update(cx, |project, cx| {
1701 project.lsp_store().update(cx, |lsp_store, cx| {
1702 // Create some diagnostics
1703 lsp_store
1704 .update_diagnostics(
1705 LanguageServerId(0),
1706 lsp::PublishDiagnosticsParams {
1707 uri: path_to_buffer_uri.clone(),
1708 diagnostics: vec![diagnostic],
1709 version: None,
1710 },
1711 None,
1712 language::DiagnosticSourceKind::Pushed,
1713 &[],
1714 cx,
1715 )
1716 .unwrap();
1717 });
1718 });
1719
1720 let buffer = project
1721 .update(cx, |project, cx| {
1722 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1723 project.open_buffer(path, cx)
1724 })
1725 .await
1726 .unwrap();
1727
1728 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1729 let position = snapshot.anchor_before(language::Point::new(0, 0));
1730
1731 let _prediction_task = zeta.update(cx, |zeta, cx| {
1732 zeta.request_prediction(&project, &buffer, position, cx)
1733 });
1734
1735 let (request, _respond_tx) = req_rx.next().await.unwrap();
1736
1737 assert_eq!(request.diagnostic_groups.len(), 1);
1738 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1739 .unwrap();
1740 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1741 assert_eq!(
1742 value,
1743 json!({
1744 "entries": [{
1745 "range": {
1746 "start": 8,
1747 "end": 10
1748 },
1749 "diagnostic": {
1750 "source": null,
1751 "code": null,
1752 "code_description": null,
1753 "severity": 1,
1754 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1755 "markdown": null,
1756 "group_id": 0,
1757 "is_primary": true,
1758 "is_disk_based": false,
1759 "is_unnecessary": false,
1760 "source_kind": "Pushed",
1761 "data": null,
1762 "underline": true
1763 }
1764 }],
1765 "primary_ix": 0
1766 })
1767 );
1768 }
1769
1770 fn init_test(
1771 cx: &mut TestAppContext,
1772 ) -> (
1773 Entity<Zeta>,
1774 mpsc::UnboundedReceiver<(
1775 predict_edits_v3::PredictEditsRequest,
1776 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1777 )>,
1778 ) {
1779 cx.update(move |cx| {
1780 let settings_store = SettingsStore::test(cx);
1781 cx.set_global(settings_store);
1782 language::init(cx);
1783 Project::init_settings(cx);
1784
1785 let (req_tx, req_rx) = mpsc::unbounded();
1786
1787 let http_client = FakeHttpClient::create({
1788 move |req| {
1789 let uri = req.uri().path().to_string();
1790 let mut body = req.into_body();
1791 let req_tx = req_tx.clone();
1792 async move {
1793 let resp = match uri.as_str() {
1794 "/client/llm_tokens" => serde_json::to_string(&json!({
1795 "token": "test"
1796 }))
1797 .unwrap(),
1798 "/predict_edits/v3" => {
1799 let mut buf = Vec::new();
1800 body.read_to_end(&mut buf).await.ok();
1801 let req = serde_json::from_slice(&buf).unwrap();
1802
1803 let (res_tx, res_rx) = oneshot::channel();
1804 req_tx.unbounded_send((req, res_tx)).unwrap();
1805 serde_json::to_string(&res_rx.await?).unwrap()
1806 }
1807 _ => {
1808 panic!("Unexpected path: {}", uri)
1809 }
1810 };
1811
1812 Ok(Response::builder().body(resp.into()).unwrap())
1813 }
1814 }
1815 });
1816
1817 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1818 client.cloud_client().set_credentials(1, "test".into());
1819
1820 language_model::init(client.clone(), cx);
1821
1822 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1823 let zeta = Zeta::global(&client, &user_store, cx);
1824
1825 (zeta, req_rx)
1826 })
1827 }
1828}