1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::ops::Range;
32use std::path::Path;
33use std::str::FromStr as _;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use thiserror::Error;
37use util::ResultExt as _;
38use util::rel_path::RelPathBuf;
39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
40
41mod merge_excerpts;
42mod prediction;
43mod provider;
44mod related_excerpts;
45
46use crate::merge_excerpts::merge_excerpts;
47use crate::prediction::EditPrediction;
48use crate::related_excerpts::find_related_excerpts;
49pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
50pub use provider::ZetaEditPredictionProvider;
51
/// Edits to the same buffer that follow the previous change within this window are folded
/// into a single history event (see `push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Maximum number of events to track per project; once reached, the oldest half is dropped.
const MAX_EVENT_COUNT: usize = 16;
56
57pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
58 max_bytes: 512,
59 min_bytes: 128,
60 target_before_cursor_over_total_bytes: 0.5,
61};
62
63pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
64
65pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
66 excerpt: DEFAULT_EXCERPT_OPTIONS,
67};
68
69pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
70 EditPredictionContextOptions {
71 use_imports: true,
72 max_retrieved_declarations: 0,
73 excerpt: DEFAULT_EXCERPT_OPTIONS,
74 score: EditPredictionScoreOptions {
75 omit_excerpt_overlaps: true,
76 },
77 };
78
79pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
80 context: DEFAULT_CONTEXT_OPTIONS,
81 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
82 max_diagnostic_bytes: 2048,
83 prompt_format: PromptFormat::DEFAULT,
84 file_indexing_parallelism: 1,
85};
86
87pub struct Zeta2FeatureFlag;
88
89impl FeatureFlag for Zeta2FeatureFlag {
90 const NAME: &'static str = "zeta2";
91
92 fn enabled_for_staff() -> bool {
93 false
94 }
95}
96
97#[derive(Clone)]
98struct ZetaGlobal(Entity<Zeta>);
99
100impl Global for ZetaGlobal {}
101
/// Edit-prediction engine: tracks per-project edit history and context and sends
/// prediction requests to the LLM service.
pub struct Zeta {
    /// Client used for HTTP requests and to build LLM service URLs.
    client: Arc<Client>,
    /// Store queried for, and updated with, edit-prediction usage.
    user_store: Entity<UserStore>,
    /// Bearer token for LLM requests; refreshed whenever the global refresh listener fires.
    llm_token: LlmApiToken,
    /// Subscription to `RefreshLlmTokenListener`, held for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Current request configuration.
    options: ZetaOptions,
    /// Set when a response fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// Debug event sink, installed by `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
112
/// Configuration for how prediction requests are built.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for the prompt is gathered.
    pub context: ContextMode,
    /// Prompt size limit, forwarded in the request as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format sent with the request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
121
122#[derive(Debug, Clone, PartialEq)]
123pub enum ContextMode {
124 Llm(LlmContextOptions),
125 Syntax(EditPredictionContextOptions),
126}
127
128impl ContextMode {
129 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
130 match self {
131 ContextMode::Llm(options) => &options.excerpt,
132 ContextMode::Syntax(options) => &options.excerpt,
133 }
134 }
135}
136
/// Events delivered to the receiver returned by `Zeta::debug_info`.
pub enum ZetaDebugInfo {
    /// Sent by `refresh_context` when a context retrieval starts.
    ContextRetrievalStarted(ZetaContextRetrievalDebugInfo),
    /// NOTE(review): not emitted in this file; presumably sent by `find_related_excerpts`,
    /// which receives the debug sender.
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    /// NOTE(review): not emitted in this file; presumably sent by `find_related_excerpts`.
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    /// Sent by `refresh_context` after the retrieved context has been stored.
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// Sent by `request_prediction` for every request it builds.
    EditPredicted(ZetaEditPredictionDebugInfo),
}

/// Payload for context-retrieval debug events.
pub struct ZetaContextRetrievalDebugInfo {
    /// Project whose context is being retrieved.
    pub project: Entity<Project>,
    /// When the event was emitted.
    pub timestamp: Instant,
}

/// Payload for `ZetaDebugInfo::EditPredicted`.
pub struct ZetaEditPredictionDebugInfo {
    /// The request as it will be sent.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context and building the request (measured after the syntax
    /// index lock is acquired).
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt built locally with `build_prompt`, or its error message.
    pub local_prompt: Result<String, String>,
    /// Receives the server response, or an error message (including when the request is skipped).
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
pub struct ZetaSearchQueryDebugInfo {
    /// Project the queries were generated for.
    pub project: Entity<Project>,
    /// When the event was emitted.
    pub timestamp: Instant,
    /// The generated search queries.
    pub queries: Vec<SearchToolQuery>,
}

/// Debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
166
/// Prediction state tracked for one project.
struct ZetaProject {
    /// Syntax index created when the project is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change history; `push_event` coalesces and caps it at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being recorded, keyed by entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Latest prediction, replaced per `should_replace_prediction` and cleared on accept/discard.
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts from the last completed context refresh (LLM context mode only).
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// Context retrieval started by `refresh_context`; cleared by the task when it finishes.
    refresh_context_task: Option<Task<Option<()>>>,
    /// Debounce timer started by `refresh_context_if_needed`.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When `refresh_context_if_needed` was last called; used to detect idle periods.
    refresh_context_timestamp: Option<Instant>,
}
177
178#[derive(Debug, Clone)]
179struct CurrentEditPrediction {
180 pub requested_by_buffer_id: EntityId,
181 pub prediction: EditPrediction,
182}
183
184impl CurrentEditPrediction {
185 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
186 let Some(new_edits) = self
187 .prediction
188 .interpolate(&self.prediction.buffer.read(cx))
189 else {
190 return false;
191 };
192
193 if self.prediction.buffer != old_prediction.prediction.buffer {
194 return true;
195 }
196
197 let Some(old_edits) = old_prediction
198 .prediction
199 .interpolate(&old_prediction.prediction.buffer.read(cx))
200 else {
201 return true;
202 };
203
204 // This reduces the occurrence of UI thrash from replacing edits
205 //
206 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
207 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
208 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
209 && old_edits.len() == 1
210 && new_edits.len() == 1
211 {
212 let (old_range, old_text) = &old_edits[0];
213 let (new_range, new_text) = &new_edits[0];
214 new_range == old_range && new_text.starts_with(old_text)
215 } else {
216 true
217 }
218 }
219}
220
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer (`targets_buffer` returned true).
    Local { prediction: &'a EditPrediction },
    /// The prediction edits another buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}

/// A buffer whose edits are recorded as history events.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer, kept alive while the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
}
232
233#[derive(Clone)]
234pub enum Event {
235 BufferChange {
236 old_snapshot: BufferSnapshot,
237 new_snapshot: BufferSnapshot,
238 timestamp: Instant,
239 },
240}
241
242impl Event {
243 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
244 match self {
245 Event::BufferChange {
246 old_snapshot,
247 new_snapshot,
248 ..
249 } => {
250 let path = new_snapshot.file().map(|f| f.full_path(cx));
251
252 let old_path = old_snapshot.file().and_then(|f| {
253 let old_path = f.full_path(cx);
254 if Some(&old_path) != path.as_ref() {
255 Some(old_path)
256 } else {
257 None
258 }
259 });
260
261 // TODO [zeta2] move to bg?
262 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
263
264 if path == old_path && diff.is_empty() {
265 None
266 } else {
267 Some(predict_edits_v3::Event::BufferChange {
268 old_path,
269 path,
270 diff,
271 //todo: Actually detect if this edit was predicted or not
272 predicted: false,
273 })
274 }
275 }
276 }
277 }
278}
279
280impl Zeta {
281 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
282 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
283 }
284
285 pub fn global(
286 client: &Arc<Client>,
287 user_store: &Entity<UserStore>,
288 cx: &mut App,
289 ) -> Entity<Self> {
290 cx.try_global::<ZetaGlobal>()
291 .map(|global| global.0.clone())
292 .unwrap_or_else(|| {
293 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
294 cx.set_global(ZetaGlobal(zeta.clone()));
295 zeta
296 })
297 }
298
299 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
300 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
301
302 Self {
303 projects: HashMap::default(),
304 client,
305 user_store,
306 options: DEFAULT_OPTIONS,
307 llm_token: LlmApiToken::default(),
308 _llm_token_subscription: cx.subscribe(
309 &refresh_llm_token_listener,
310 |this, _listener, _event, cx| {
311 let client = this.client.clone();
312 let llm_token = this.llm_token.clone();
313 cx.spawn(async move |_this, _cx| {
314 llm_token.refresh(&client).await?;
315 anyhow::Ok(())
316 })
317 .detach_and_log_err(cx);
318 },
319 ),
320 update_required: false,
321 debug_tx: None,
322 }
323 }
324
325 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
326 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
327 self.debug_tx = Some(debug_watch_tx);
328 debug_watch_rx
329 }
330
331 pub fn options(&self) -> &ZetaOptions {
332 &self.options
333 }
334
335 pub fn set_options(&mut self, options: ZetaOptions) {
336 self.options = options;
337 }
338
339 pub fn clear_history(&mut self) {
340 for zeta_project in self.projects.values_mut() {
341 zeta_project.events.clear();
342 }
343 }
344
345 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
346 self.projects
347 .get(&project.entity_id())
348 .map(|project| project.events.iter())
349 .into_iter()
350 .flatten()
351 }
352
353 pub fn context_for_project(
354 &self,
355 project: &Entity<Project>,
356 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
357 self.projects
358 .get(&project.entity_id())
359 .and_then(|project| {
360 Some(
361 project
362 .context
363 .as_ref()?
364 .iter()
365 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
366 )
367 })
368 .into_iter()
369 .flatten()
370 }
371
372 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
373 self.user_store.read(cx).edit_prediction_usage()
374 }
375
376 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
377 self.get_or_init_zeta_project(project, cx);
378 }
379
380 pub fn register_buffer(
381 &mut self,
382 buffer: &Entity<Buffer>,
383 project: &Entity<Project>,
384 cx: &mut Context<Self>,
385 ) {
386 let zeta_project = self.get_or_init_zeta_project(project, cx);
387 Self::register_buffer_impl(zeta_project, buffer, project, cx);
388 }
389
390 fn get_or_init_zeta_project(
391 &mut self,
392 project: &Entity<Project>,
393 cx: &mut App,
394 ) -> &mut ZetaProject {
395 self.projects
396 .entry(project.entity_id())
397 .or_insert_with(|| ZetaProject {
398 syntax_index: cx.new(|cx| {
399 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
400 }),
401 events: VecDeque::new(),
402 registered_buffers: HashMap::default(),
403 current_prediction: None,
404 context: None,
405 refresh_context_task: None,
406 refresh_context_debounce_task: None,
407 refresh_context_timestamp: None,
408 })
409 }
410
411 fn register_buffer_impl<'a>(
412 zeta_project: &'a mut ZetaProject,
413 buffer: &Entity<Buffer>,
414 project: &Entity<Project>,
415 cx: &mut Context<Self>,
416 ) -> &'a mut RegisteredBuffer {
417 let buffer_id = buffer.entity_id();
418 match zeta_project.registered_buffers.entry(buffer_id) {
419 hash_map::Entry::Occupied(entry) => entry.into_mut(),
420 hash_map::Entry::Vacant(entry) => {
421 let snapshot = buffer.read(cx).snapshot();
422 let project_entity_id = project.entity_id();
423 entry.insert(RegisteredBuffer {
424 snapshot,
425 _subscriptions: [
426 cx.subscribe(buffer, {
427 let project = project.downgrade();
428 move |this, buffer, event, cx| {
429 if let language::BufferEvent::Edited = event
430 && let Some(project) = project.upgrade()
431 {
432 this.report_changes_for_buffer(&buffer, &project, cx);
433 }
434 }
435 }),
436 cx.observe_release(buffer, move |this, _buffer, _cx| {
437 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
438 else {
439 return;
440 };
441 zeta_project.registered_buffers.remove(&buffer_id);
442 }),
443 ],
444 })
445 }
446 }
447 }
448
449 fn report_changes_for_buffer(
450 &mut self,
451 buffer: &Entity<Buffer>,
452 project: &Entity<Project>,
453 cx: &mut Context<Self>,
454 ) -> BufferSnapshot {
455 let zeta_project = self.get_or_init_zeta_project(project, cx);
456 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
457
458 let new_snapshot = buffer.read(cx).snapshot();
459 if new_snapshot.version != registered_buffer.snapshot.version {
460 let old_snapshot =
461 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
462 Self::push_event(
463 zeta_project,
464 Event::BufferChange {
465 old_snapshot,
466 new_snapshot: new_snapshot.clone(),
467 timestamp: Instant::now(),
468 },
469 );
470 }
471
472 new_snapshot
473 }
474
475 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
476 let events = &mut zeta_project.events;
477
478 if let Some(Event::BufferChange {
479 new_snapshot: last_new_snapshot,
480 timestamp: last_timestamp,
481 ..
482 }) = events.back_mut()
483 {
484 // Coalesce edits for the same buffer when they happen one after the other.
485 let Event::BufferChange {
486 old_snapshot,
487 new_snapshot,
488 timestamp,
489 } = &event;
490
491 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
492 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
493 && old_snapshot.version == last_new_snapshot.version
494 {
495 *last_new_snapshot = new_snapshot.clone();
496 *last_timestamp = *timestamp;
497 return;
498 }
499 }
500
501 if events.len() >= MAX_EVENT_COUNT {
502 // These are halved instead of popping to improve prompt caching.
503 events.drain(..MAX_EVENT_COUNT / 2);
504 }
505
506 events.push_back(event);
507 }
508
509 fn current_prediction_for_buffer(
510 &self,
511 buffer: &Entity<Buffer>,
512 project: &Entity<Project>,
513 cx: &App,
514 ) -> Option<BufferEditPrediction<'_>> {
515 let project_state = self.projects.get(&project.entity_id())?;
516
517 let CurrentEditPrediction {
518 requested_by_buffer_id,
519 prediction,
520 } = project_state.current_prediction.as_ref()?;
521
522 if prediction.targets_buffer(buffer.read(cx), cx) {
523 Some(BufferEditPrediction::Local { prediction })
524 } else if *requested_by_buffer_id == buffer.entity_id() {
525 Some(BufferEditPrediction::Jump { prediction })
526 } else {
527 None
528 }
529 }
530
531 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
532 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
533 return;
534 };
535
536 let Some(prediction) = project_state.current_prediction.take() else {
537 return;
538 };
539 let request_id = prediction.prediction.id.into();
540
541 let client = self.client.clone();
542 let llm_token = self.llm_token.clone();
543 let app_version = AppVersion::global(cx);
544 cx.spawn(async move |this, cx| {
545 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
546 http_client::Url::parse(&predict_edits_url)?
547 } else {
548 client
549 .http_client()
550 .build_zed_llm_url("/predict_edits/accept", &[])?
551 };
552
553 let response = cx
554 .background_spawn(Self::send_api_request::<()>(
555 move |builder| {
556 let req = builder.uri(url.as_ref()).body(
557 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
558 );
559 Ok(req?)
560 },
561 client,
562 llm_token,
563 app_version,
564 ))
565 .await;
566
567 Self::handle_api_response(&this, response, cx)?;
568 anyhow::Ok(())
569 })
570 .detach_and_log_err(cx);
571 }
572
573 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
574 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
575 project_state.current_prediction.take();
576 };
577 }
578
579 pub fn refresh_prediction(
580 &mut self,
581 project: &Entity<Project>,
582 buffer: &Entity<Buffer>,
583 position: language::Anchor,
584 cx: &mut Context<Self>,
585 ) -> Task<Result<()>> {
586 let request_task = self.request_prediction(project, buffer, position, cx);
587 let buffer = buffer.clone();
588 let project = project.clone();
589
590 cx.spawn(async move |this, cx| {
591 if let Some(prediction) = request_task.await? {
592 this.update(cx, |this, cx| {
593 let project_state = this
594 .projects
595 .get_mut(&project.entity_id())
596 .context("Project not found")?;
597
598 let new_prediction = CurrentEditPrediction {
599 requested_by_buffer_id: buffer.entity_id(),
600 prediction: prediction,
601 };
602
603 if project_state
604 .current_prediction
605 .as_ref()
606 .is_none_or(|old_prediction| {
607 new_prediction.should_replace_prediction(&old_prediction, cx)
608 })
609 {
610 project_state.current_prediction = Some(new_prediction);
611 }
612 anyhow::Ok(())
613 })??;
614 }
615 Ok(())
616 })
617 }
618
    /// Builds and sends a prediction request for `position` in `buffer`, then converts the
    /// response into an [`EditPrediction`].
    ///
    /// Context gathering (per `self.options.context`) and the network request run on a
    /// background task. The response is resolved on the foreground against the snapshot
    /// taken here. The task resolves to `Ok(None)` when no excerpt/context could be selected
    /// around the cursor. It fails when the buffer has no file path or the request fails.
    fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle; it is locked on the background task below.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert the recorded history into request events; events for which
        // `to_request_event` returns `None` are dropped.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        // Absolute path of the directory containing the buffer's file, if it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Related excerpts from the last context refresh as (snapshot, path, ranges);
        // buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer, ranges)| {
                let buffer = buffer.read(cx);
                Some((
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                // Retrieval time is measured from here, i.e. after the index lock is held.
                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = match options.context {
                    ContextMode::Llm(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
                            ..snapshot.anchor_before(excerpt.range.end);

                        // Add the cursor excerpt to the included files. If this buffer already
                        // has related excerpts, insert the range at the binary-search position
                        // (start ascending, then end descending) and swap that entry into the
                        // last slot. NOTE(review): presumably so the cursor's file comes last;
                        // confirm. Otherwise append a new entry for this buffer.
                        if let Some(buffer_ix) = included_files
                            .iter()
                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
                        {
                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert each file's anchor ranges to row ranges and combine them
                        // with `merge_excerpts`.
                        let included_files = included_files
                            .into_iter()
                            .map(|(buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path,
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        // The cursor excerpt travels inside `included_files`, so the dedicated
                        // excerpt fields are left empty.
                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        // Gather context from the syntax index around the cursor.
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When a debug sink is installed, send the request, a locally built prompt,
                // and a channel that will receive the server's response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    let local_prompt = build_prompt(&request)
                        .map(|(prompt, _)| prompt)
                        .map_err(|err| err.to_string());

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
                            request: request.clone(),
                            retrieval_time,
                            buffer: buffer.downgrade(),
                            local_prompt,
                            position,
                            response_rx,
                        }))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug builds can skip the network call with ZED_ZETA2_SKIP_REQUEST.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send(Err("Request skipped".to_string()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let response =
                    Self::send_prediction_request(client, llm_token, app_version, request).await;

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                        )
                        .ok();
                }

                response.map(|(res, usage)| (Some(res), usage))
            }
        });

        let buffer = buffer.clone();

        cx.spawn({
            let project = project.clone();
            async move |this, cx| {
                // Records usage / update-required state; `None` means no context was found.
                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
            }
        })
    }
883
884 async fn send_prediction_request(
885 client: Arc<Client>,
886 llm_token: LlmApiToken,
887 app_version: SemanticVersion,
888 request: predict_edits_v3::PredictEditsRequest,
889 ) -> Result<(
890 predict_edits_v3::PredictEditsResponse,
891 Option<EditPredictionUsage>,
892 )> {
893 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
894 http_client::Url::parse(&predict_edits_url)?
895 } else {
896 client
897 .http_client()
898 .build_zed_llm_url("/predict_edits/v3", &[])?
899 };
900
901 Self::send_api_request(
902 |builder| {
903 let req = builder
904 .uri(url.as_ref())
905 .body(serde_json::to_string(&request)?.into());
906 Ok(req?)
907 },
908 client,
909 llm_token,
910 app_version,
911 )
912 .await
913 }
914
915 fn handle_api_response<T>(
916 this: &WeakEntity<Self>,
917 response: Result<(T, Option<EditPredictionUsage>)>,
918 cx: &mut gpui::AsyncApp,
919 ) -> Result<T> {
920 match response {
921 Ok((data, usage)) => {
922 if let Some(usage) = usage {
923 this.update(cx, |this, cx| {
924 this.user_store.update(cx, |user_store, cx| {
925 user_store.update_edit_prediction_usage(usage, cx);
926 });
927 })
928 .ok();
929 }
930 Ok(data)
931 }
932 Err(err) => {
933 if err.is::<ZedUpdateRequiredError>() {
934 cx.update(|cx| {
935 this.update(cx, |this, _cx| {
936 this.update_required = true;
937 })
938 .ok();
939
940 let error_message: SharedString = err.to_string().into();
941 show_app_notification(
942 NotificationId::unique::<ZedUpdateRequiredError>(),
943 cx,
944 move |cx| {
945 cx.new(|cx| {
946 ErrorMessagePrompt::new(error_message.clone(), cx)
947 .with_link_button("Update Zed", "https://zed.dev/releases")
948 })
949 },
950 );
951 })
952 .ok();
953 }
954 Err(err)
955 }
956 }
957 }
958
959 async fn send_api_request<Res>(
960 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
961 client: Arc<Client>,
962 llm_token: LlmApiToken,
963 app_version: SemanticVersion,
964 ) -> Result<(Res, Option<EditPredictionUsage>)>
965 where
966 Res: DeserializeOwned,
967 {
968 let http_client = client.http_client();
969 let mut token = llm_token.acquire(&client).await?;
970 let mut did_retry = false;
971
972 loop {
973 let request_builder = http_client::Request::builder().method(Method::POST);
974
975 let request = build(
976 request_builder
977 .header("Content-Type", "application/json")
978 .header("Authorization", format!("Bearer {}", token))
979 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
980 )?;
981
982 let mut response = http_client.send(request).await?;
983
984 if let Some(minimum_required_version) = response
985 .headers()
986 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
987 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
988 {
989 anyhow::ensure!(
990 app_version >= minimum_required_version,
991 ZedUpdateRequiredError {
992 minimum_version: minimum_required_version
993 }
994 );
995 }
996
997 if response.status().is_success() {
998 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
999
1000 let mut body = Vec::new();
1001 response.body_mut().read_to_end(&mut body).await?;
1002 return Ok((serde_json::from_slice(&body)?, usage));
1003 } else if !did_retry
1004 && response
1005 .headers()
1006 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1007 .is_some()
1008 {
1009 did_retry = true;
1010 token = llm_token.refresh(&client).await?;
1011 } else {
1012 let mut body = String::new();
1013 response.body_mut().read_to_string(&mut body).await?;
1014 anyhow::bail!(
1015 "Request failed with status: {:?}\nBody: {}",
1016 response.status(),
1017 body
1018 );
1019 }
1020 }
1021 }
1022
1023 pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1024 pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1025
1026 // Refresh the related excerpts when the user just beguns editing after
1027 // an idle period, and after they pause editing.
1028 fn refresh_context_if_needed(
1029 &mut self,
1030 project: &Entity<Project>,
1031 buffer: &Entity<language::Buffer>,
1032 cursor_position: language::Anchor,
1033 cx: &mut Context<Self>,
1034 ) {
1035 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1036 return;
1037 }
1038
1039 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1040 return;
1041 };
1042
1043 let now = Instant::now();
1044 let was_idle = zeta_project
1045 .refresh_context_timestamp
1046 .map_or(true, |timestamp| {
1047 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1048 });
1049 zeta_project.refresh_context_timestamp = Some(now);
1050 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1051 let buffer = buffer.clone();
1052 let project = project.clone();
1053 async move |this, cx| {
1054 if was_idle {
1055 log::debug!("refetching edit prediction context after idle");
1056 } else {
1057 cx.background_executor()
1058 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1059 .await;
1060 log::debug!("refetching edit prediction context after pause");
1061 }
1062 this.update(cx, |this, cx| {
1063 this.refresh_context(project, buffer, cursor_position, cx);
1064 })
1065 .ok()
1066 }
1067 }));
1068 }
1069
    /// Refreshes the related-excerpt context for `project` in the background.
    ///
    /// When `ContextMode::Llm` is configured, excerpts related to `cursor_position` in
    /// `buffer` are retrieved with `find_related_excerpts`, using the project's recorded
    /// events. The result is stored in the project's `context`. A retrieval error is logged
    /// and an empty map is stored instead. In any other context mode an empty map is stored.
    ///
    /// The spawned task is kept in the project's `refresh_context_task` so it is not dropped
    /// before it finishes, and the task clears that slot itself when it completes. Calling
    /// this again replaces the stored task. NOTE(review): this relies on dropping a gpui
    /// `Task` cancelling it, which is what keeps at most one retrieval per project in
    /// flight — confirm.
    ///
    /// If a debug channel is attached, `ContextRetrievalStarted` is sent when the task starts.
    /// `ContextRetrievalFinished` is sent after the context has been stored.
    ///
    /// Does nothing if `project` has not been registered.
    fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Cloned so the debug channel can be used from inside the spawned task.
        let debug_tx = self.debug_tx.clone();

        zeta_project.refresh_context_task = Some(cx.spawn(async move |this, cx| {
            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            let related_excerpts = this
                .update(cx, |this, cx| {
                    // Project unregistered since the refresh was requested: nothing to gather.
                    let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
                        return Task::ready(anyhow::Ok(HashMap::default()));
                    };

                    // Only LLM-based context mode retrieves related excerpts here.
                    let ContextMode::Llm(options) = &this.options().context else {
                        return Task::ready(anyhow::Ok(HashMap::default()));
                    };

                    find_related_excerpts(
                        buffer.clone(),
                        cursor_position,
                        &project,
                        zeta_project.events.iter(),
                        options,
                        debug_tx,
                        cx,
                    )
                })
                // Stop early if the `Zeta` entity has been released.
                .ok()?
                .await
                // A failed retrieval is logged and treated as "no related excerpts".
                .log_err()
                .unwrap_or_default();
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                zeta_project.context = Some(related_excerpts);
                // This task is finishing; release its slot so a later refresh starts fresh.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
            })
            .ok()
        }));
    }
1141
1142 fn gather_nearby_diagnostics(
1143 cursor_offset: usize,
1144 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1145 snapshot: &BufferSnapshot,
1146 max_diagnostics_bytes: usize,
1147 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1148 // TODO: Could make this more efficient
1149 let mut diagnostic_groups = Vec::new();
1150 for (language_server_id, diagnostics) in diagnostic_sets {
1151 let mut groups = Vec::new();
1152 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1153 diagnostic_groups.extend(
1154 groups
1155 .into_iter()
1156 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1157 );
1158 }
1159
1160 // sort by proximity to cursor
1161 diagnostic_groups.sort_by_key(|group| {
1162 let range = &group.entries[group.primary_ix].range;
1163 if range.start >= cursor_offset {
1164 range.start - cursor_offset
1165 } else if cursor_offset >= range.end {
1166 cursor_offset - range.end
1167 } else {
1168 (cursor_offset - range.start).min(range.end - cursor_offset)
1169 }
1170 });
1171
1172 let mut results = Vec::new();
1173 let mut diagnostic_groups_truncated = false;
1174 let mut diagnostics_byte_count = 0;
1175 for group in diagnostic_groups {
1176 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1177 diagnostics_byte_count += raw_value.get().len();
1178 if diagnostics_byte_count > max_diagnostics_bytes {
1179 diagnostic_groups_truncated = true;
1180 break;
1181 }
1182 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1183 }
1184
1185 (results, diagnostic_groups_truncated)
1186 }
1187
1188 // TODO: Dedupe with similar code in request_prediction?
1189 pub fn cloud_request_for_zeta_cli(
1190 &mut self,
1191 project: &Entity<Project>,
1192 buffer: &Entity<Buffer>,
1193 position: language::Anchor,
1194 cx: &mut Context<Self>,
1195 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1196 let project_state = self.projects.get(&project.entity_id());
1197
1198 let index_state = project_state.map(|state| {
1199 state
1200 .syntax_index
1201 .read_with(cx, |index, _cx| index.state().clone())
1202 });
1203 let options = self.options.clone();
1204 let snapshot = buffer.read(cx).snapshot();
1205 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1206 return Task::ready(Err(anyhow!("No file path for excerpt")));
1207 };
1208 let worktree_snapshots = project
1209 .read(cx)
1210 .worktrees(cx)
1211 .map(|worktree| worktree.read(cx).snapshot())
1212 .collect::<Vec<_>>();
1213
1214 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1215 let mut path = f.worktree.read(cx).absolutize(&f.path);
1216 if path.pop() { Some(path) } else { None }
1217 });
1218
1219 cx.background_spawn(async move {
1220 let index_state = if let Some(index_state) = index_state {
1221 Some(index_state.lock_owned().await)
1222 } else {
1223 None
1224 };
1225
1226 let cursor_point = position.to_point(&snapshot);
1227
1228 let debug_info = true;
1229 EditPredictionContext::gather_context(
1230 cursor_point,
1231 &snapshot,
1232 parent_abs_path.as_deref(),
1233 match &options.context {
1234 ContextMode::Llm(_) => {
1235 // TODO
1236 panic!("Llm mode not supported in zeta cli yet");
1237 }
1238 ContextMode::Syntax(edit_prediction_context_options) => {
1239 edit_prediction_context_options
1240 }
1241 },
1242 index_state.as_deref(),
1243 )
1244 .context("Failed to select excerpt")
1245 .map(|context| {
1246 make_syntax_context_cloud_request(
1247 excerpt_path.into(),
1248 context,
1249 // TODO pass everything
1250 Vec::new(),
1251 false,
1252 Vec::new(),
1253 false,
1254 None,
1255 debug_info,
1256 &worktree_snapshots,
1257 index_state.as_deref(),
1258 Some(options.max_prompt_bytes),
1259 options.prompt_format,
1260 )
1261 })
1262 })
1263 }
1264
1265 pub fn wait_for_initial_indexing(
1266 &mut self,
1267 project: &Entity<Project>,
1268 cx: &mut App,
1269 ) -> Task<Result<()>> {
1270 let zeta_project = self.get_or_init_zeta_project(project, cx);
1271 zeta_project
1272 .syntax_index
1273 .read(cx)
1274 .wait_for_initial_file_indexing(cx)
1275 }
1276}
1277
/// Error indicating that edit predictions require a newer Zed version than the one running.
///
/// Its message tells the user to update to at least `minimum_version`.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The lowest Zed version that may keep using edit predictions.
    minimum_version: SemanticVersion,
}
1285
1286fn make_syntax_context_cloud_request(
1287 excerpt_path: Arc<Path>,
1288 context: EditPredictionContext,
1289 events: Vec<predict_edits_v3::Event>,
1290 can_collect_data: bool,
1291 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1292 diagnostic_groups_truncated: bool,
1293 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1294 debug_info: bool,
1295 worktrees: &Vec<worktree::Snapshot>,
1296 index_state: Option<&SyntaxIndexState>,
1297 prompt_max_bytes: Option<usize>,
1298 prompt_format: PromptFormat,
1299) -> predict_edits_v3::PredictEditsRequest {
1300 let mut signatures = Vec::new();
1301 let mut declaration_to_signature_index = HashMap::default();
1302 let mut referenced_declarations = Vec::new();
1303
1304 for snippet in context.declarations {
1305 let project_entry_id = snippet.declaration.project_entry_id();
1306 let Some(path) = worktrees.iter().find_map(|worktree| {
1307 worktree.entry_for_id(project_entry_id).map(|entry| {
1308 let mut full_path = RelPathBuf::new();
1309 full_path.push(worktree.root_name());
1310 full_path.push(&entry.path);
1311 full_path
1312 })
1313 }) else {
1314 continue;
1315 };
1316
1317 let parent_index = index_state.and_then(|index_state| {
1318 snippet.declaration.parent().and_then(|parent| {
1319 add_signature(
1320 parent,
1321 &mut declaration_to_signature_index,
1322 &mut signatures,
1323 index_state,
1324 )
1325 })
1326 });
1327
1328 let (text, text_is_truncated) = snippet.declaration.item_text();
1329 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1330 path: path.as_std_path().into(),
1331 text: text.into(),
1332 range: snippet.declaration.item_line_range(),
1333 text_is_truncated,
1334 signature_range: snippet.declaration.signature_range_in_item_text(),
1335 parent_index,
1336 signature_score: snippet.score(DeclarationStyle::Signature),
1337 declaration_score: snippet.score(DeclarationStyle::Declaration),
1338 score_components: snippet.components,
1339 });
1340 }
1341
1342 let excerpt_parent = index_state.and_then(|index_state| {
1343 context
1344 .excerpt
1345 .parent_declarations
1346 .last()
1347 .and_then(|(parent, _)| {
1348 add_signature(
1349 *parent,
1350 &mut declaration_to_signature_index,
1351 &mut signatures,
1352 index_state,
1353 )
1354 })
1355 });
1356
1357 predict_edits_v3::PredictEditsRequest {
1358 excerpt_path,
1359 excerpt: context.excerpt_text.body,
1360 excerpt_line_range: context.excerpt.line_range,
1361 excerpt_range: context.excerpt.range,
1362 cursor_point: predict_edits_v3::Point {
1363 line: predict_edits_v3::Line(context.cursor_point.row),
1364 column: context.cursor_point.column,
1365 },
1366 referenced_declarations,
1367 included_files: vec![],
1368 signatures,
1369 excerpt_parent,
1370 events,
1371 can_collect_data,
1372 diagnostic_groups,
1373 diagnostic_groups_truncated,
1374 git_info,
1375 debug_info,
1376 prompt_max_bytes,
1377 prompt_format,
1378 }
1379}
1380
1381fn add_signature(
1382 declaration_id: DeclarationId,
1383 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1384 signatures: &mut Vec<Signature>,
1385 index: &SyntaxIndexState,
1386) -> Option<usize> {
1387 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1388 return Some(*signature_index);
1389 }
1390 let Some(parent_declaration) = index.declaration(declaration_id) else {
1391 log::error!("bug: missing parent declaration");
1392 return None;
1393 };
1394 let parent_index = parent_declaration.parent().and_then(|parent| {
1395 add_signature(parent, declaration_to_signature_index, signatures, index)
1396 });
1397 let (text, text_is_truncated) = parent_declaration.signature_text();
1398 let signature_index = signatures.len();
1399 signatures.push(Signature {
1400 text: text.into(),
1401 text_is_truncated,
1402 parent_index,
1403 range: parent_declaration.signature_line_range(),
1404 });
1405 declaration_to_signature_index.insert(declaration_id, signature_index);
1406 Some(signature_index)
1407}
1408
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Checks how `current_prediction_for_buffer` classifies the stored prediction.
    ///
    /// - An edit to the queried buffer's own file is `Local`.
    /// - An edit to another file is a `Jump` when queried from `buffer1`.
    /// - That same edit is `Local` once the target file's buffer is opened and queried.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // The fake server (see `init_test`) yields each request with a sender for its response.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        // Both files have three lines, so `snapshot1`'s row count also spans all of 2.txt.
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Once 2.txt is open, the same prediction is reported as local to that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Checks that a request carries the excerpt path and cursor point. It also checks that a
    /// whole-file replacement response is reduced to a single edit inserting " are you?" at
    /// the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an edit made to a registered buffer is sent with the next request as a
    /// `BufferChange` event carrying a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that diagnostics published for the buffer's file are included in the request's
    /// `diagnostic_groups`, serialized as JSON.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        // The responder is bound to a named variable, so the sender stays alive until the end of
        // the test. No response is ever sent; only the request is inspected.
        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        // The expected offsets 8..10 are on the second line ("Bye", starting at offset 7). The
        // end of 10 reflects the LSP end column 5 being limited to the end of "Bye".
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Creates a global `Zeta` backed by a fake HTTP client and returns it with a receiver of
    /// intercepted prediction requests.
    ///
    /// The fake client serves two paths:
    /// - `/client/llm_tokens` returns a fixed "test" token.
    /// - `/predict_edits/v3` decodes the request body and forwards it on the returned channel,
    ///   together with a oneshot sender. The HTTP response is whatever the test sends back on
    ///   that oneshot.
    ///
    /// Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test, then wait for its response.
                                // If the test drops the sender, the `?` fails this HTTP call.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}