1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::fmt::Write;
32use std::ops::Range;
33use std::path::Path;
34use std::str::FromStr as _;
35use std::sync::Arc;
36use std::time::{Duration, Instant};
37use thiserror::Error;
38use util::ResultExt as _;
39use util::rel_path::RelPathBuf;
40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
41
42pub mod merge_excerpts;
43mod prediction;
44mod provider;
45pub mod related_excerpts;
46
47use crate::merge_excerpts::merge_excerpts;
48use crate::prediction::EditPrediction;
49use crate::related_excerpts::find_related_excerpts;
50pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
51pub use provider::ZetaEditPredictionProvider;
52
/// Maximum time between two consecutive edits to the same buffer for `push_event` to merge
/// them into a single history event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When the history reaches this size, `push_event` drops the oldest half at once.
const MAX_EVENT_COUNT: usize = 16;
57
/// Excerpt sizing shared by the default LLM and syntax context options.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Presumably the share of the excerpt placed before the cursor — TODO confirm.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Context mode used by `DEFAULT_OPTIONS`: LLM-driven retrieval of related excerpts.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);

/// Default options for `ContextMode::Llm`.
pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};
69
/// Default options for `ContextMode::Syntax` (context gathered from the syntax index).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // NOTE(review): 0 looks like it disables declaration retrieval entirely — confirm.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Options a new `Zeta` starts with (see `Zeta::new`); replace them with `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
87
/// Feature flag gating Zeta2 edit predictions.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff do not get this flag enabled automatically.
    fn enabled_for_staff() -> bool {
        false
    }
}
97
/// App-global handle to the shared `Zeta` entity, installed by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
102
/// Zeta2 edit-prediction engine: tracks per-project edit history and retrieved context,
/// requests predictions from the LLM service, and reports accepted predictions.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Sent as the bearer token on API requests; refreshed when `RefreshLlmTokenListener`
    /// emits or when the server flags it as expired.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is below its minimum.
    update_required: bool,
    /// When present, debug events are streamed here (see `Zeta::debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
113
/// Tunable settings for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Strategy used to gather context for prediction requests.
    pub context: ContextMode,
    /// Forwarded to the request as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project's syntax index is created.
    pub file_indexing_parallelism: usize,
}
122
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Related excerpts are retrieved via `find_related_excerpts` and cached per project.
    Llm(LlmContextOptions),
    /// Context is gathered from the project's syntax index at request time.
    Syntax(EditPredictionContextOptions),
}

impl ContextMode {
    /// Excerpt options of whichever mode is selected.
    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
        match self {
            ContextMode::Llm(options) => &options.excerpt,
            ContextMode::Syntax(options) => &options.excerpt,
        }
    }
}
137
/// Events streamed to the receiver returned by `Zeta::debug_info`.
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
    /// Sent by `refresh_context` after the related excerpts have been stored.
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// Sent by `request_prediction` once the request is built, before it is sent.
    EditPredicted(ZetaEditPredictionDebugInfo),
}
146
/// Payload for `ZetaDebugInfo::ContextRetrievalStarted`.
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload shared by the context-retrieval progress variants of `ZetaDebugInfo`.
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for `ZetaDebugInfo::EditPredicted`.
pub struct ZetaEditPredictionDebugInfo {
    /// The request exactly as it is sent to the server.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics/context and building the request.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Prompt rendered locally with `build_prompt`, or the error it produced.
    pub local_prompt: Result<String, String>,
    /// Receives the server response (or an error message) once it is available.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub queries: Vec<SearchToolQuery>,
}

/// Alias for the protocol's request debug-info type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
174
/// Per-project Zeta state.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; never longer than `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts from the most recent context refresh (LLM context mode).
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context refresh; `refresh_context` keeps at most one.
    refresh_context_task: Option<Task<Option<()>>>,
    /// Pending (possibly debounced) context refresh scheduled by `refresh_context_if_needed`.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When `refresh_context_if_needed` last ran; used to detect idle periods.
    refresh_context_timestamp: Option<Instant>,
}
185
/// A project's current prediction together with the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Buffer whose cursor triggered the request; may differ from `prediction.buffer`.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
191
192impl CurrentEditPrediction {
193 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
194 let Some(new_edits) = self
195 .prediction
196 .interpolate(&self.prediction.buffer.read(cx))
197 else {
198 return false;
199 };
200
201 if self.prediction.buffer != old_prediction.prediction.buffer {
202 return true;
203 }
204
205 let Some(old_edits) = old_prediction
206 .prediction
207 .interpolate(&old_prediction.prediction.buffer.read(cx))
208 else {
209 return true;
210 };
211
212 // This reduces the occurrence of UI thrash from replacing edits
213 //
214 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
215 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
216 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
217 && old_edits.len() == 1
218 && new_edits.len() == 1
219 {
220 let (old_range, old_text) = &old_edits[0];
221 let (new_range, new_text) = &new_edits[0];
222 new_range == old_range && new_text.starts_with(old_text)
223 } else {
224 true
225 }
226 }
227}
228
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits a different buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}
235
/// A buffer Zeta is watching for edits.
struct RegisteredBuffer {
    /// Snapshot as of the last observed change; the next change event starts from here.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer, kept alive while the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
}
240
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. Consecutive edits may have been
    /// coalesced into a single event by `Zeta::push_event`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the (latest coalesced) change was recorded.
        timestamp: Instant,
    },
}
249
250impl Event {
251 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
252 match self {
253 Event::BufferChange {
254 old_snapshot,
255 new_snapshot,
256 ..
257 } => {
258 let path = new_snapshot.file().map(|f| f.full_path(cx));
259
260 let old_path = old_snapshot.file().and_then(|f| {
261 let old_path = f.full_path(cx);
262 if Some(&old_path) != path.as_ref() {
263 Some(old_path)
264 } else {
265 None
266 }
267 });
268
269 // TODO [zeta2] move to bg?
270 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
271
272 if path == old_path && diff.is_empty() {
273 None
274 } else {
275 Some(predict_edits_v3::Event::BufferChange {
276 old_path,
277 path,
278 diff,
279 //todo: Actually detect if this edit was predicted or not
280 predicted: false,
281 })
282 }
283 }
284 }
285 }
286}
287
288impl Zeta {
289 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
290 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
291 }
292
293 pub fn global(
294 client: &Arc<Client>,
295 user_store: &Entity<UserStore>,
296 cx: &mut App,
297 ) -> Entity<Self> {
298 cx.try_global::<ZetaGlobal>()
299 .map(|global| global.0.clone())
300 .unwrap_or_else(|| {
301 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
302 cx.set_global(ZetaGlobal(zeta.clone()));
303 zeta
304 })
305 }
306
307 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
308 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
309
310 Self {
311 projects: HashMap::default(),
312 client,
313 user_store,
314 options: DEFAULT_OPTIONS,
315 llm_token: LlmApiToken::default(),
316 _llm_token_subscription: cx.subscribe(
317 &refresh_llm_token_listener,
318 |this, _listener, _event, cx| {
319 let client = this.client.clone();
320 let llm_token = this.llm_token.clone();
321 cx.spawn(async move |_this, _cx| {
322 llm_token.refresh(&client).await?;
323 anyhow::Ok(())
324 })
325 .detach_and_log_err(cx);
326 },
327 ),
328 update_required: false,
329 debug_tx: None,
330 }
331 }
332
333 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
334 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
335 self.debug_tx = Some(debug_watch_tx);
336 debug_watch_rx
337 }
338
339 pub fn options(&self) -> &ZetaOptions {
340 &self.options
341 }
342
343 pub fn set_options(&mut self, options: ZetaOptions) {
344 self.options = options;
345 }
346
347 pub fn clear_history(&mut self) {
348 for zeta_project in self.projects.values_mut() {
349 zeta_project.events.clear();
350 }
351 }
352
353 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
354 self.projects
355 .get(&project.entity_id())
356 .map(|project| project.events.iter())
357 .into_iter()
358 .flatten()
359 }
360
361 pub fn context_for_project(
362 &self,
363 project: &Entity<Project>,
364 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
365 self.projects
366 .get(&project.entity_id())
367 .and_then(|project| {
368 Some(
369 project
370 .context
371 .as_ref()?
372 .iter()
373 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
374 )
375 })
376 .into_iter()
377 .flatten()
378 }
379
380 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
381 self.user_store.read(cx).edit_prediction_usage()
382 }
383
384 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
385 self.get_or_init_zeta_project(project, cx);
386 }
387
388 pub fn register_buffer(
389 &mut self,
390 buffer: &Entity<Buffer>,
391 project: &Entity<Project>,
392 cx: &mut Context<Self>,
393 ) {
394 let zeta_project = self.get_or_init_zeta_project(project, cx);
395 Self::register_buffer_impl(zeta_project, buffer, project, cx);
396 }
397
398 fn get_or_init_zeta_project(
399 &mut self,
400 project: &Entity<Project>,
401 cx: &mut App,
402 ) -> &mut ZetaProject {
403 self.projects
404 .entry(project.entity_id())
405 .or_insert_with(|| ZetaProject {
406 syntax_index: cx.new(|cx| {
407 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
408 }),
409 events: VecDeque::new(),
410 registered_buffers: HashMap::default(),
411 current_prediction: None,
412 context: None,
413 refresh_context_task: None,
414 refresh_context_debounce_task: None,
415 refresh_context_timestamp: None,
416 })
417 }
418
419 fn register_buffer_impl<'a>(
420 zeta_project: &'a mut ZetaProject,
421 buffer: &Entity<Buffer>,
422 project: &Entity<Project>,
423 cx: &mut Context<Self>,
424 ) -> &'a mut RegisteredBuffer {
425 let buffer_id = buffer.entity_id();
426 match zeta_project.registered_buffers.entry(buffer_id) {
427 hash_map::Entry::Occupied(entry) => entry.into_mut(),
428 hash_map::Entry::Vacant(entry) => {
429 let snapshot = buffer.read(cx).snapshot();
430 let project_entity_id = project.entity_id();
431 entry.insert(RegisteredBuffer {
432 snapshot,
433 _subscriptions: [
434 cx.subscribe(buffer, {
435 let project = project.downgrade();
436 move |this, buffer, event, cx| {
437 if let language::BufferEvent::Edited = event
438 && let Some(project) = project.upgrade()
439 {
440 this.report_changes_for_buffer(&buffer, &project, cx);
441 }
442 }
443 }),
444 cx.observe_release(buffer, move |this, _buffer, _cx| {
445 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
446 else {
447 return;
448 };
449 zeta_project.registered_buffers.remove(&buffer_id);
450 }),
451 ],
452 })
453 }
454 }
455 }
456
457 fn report_changes_for_buffer(
458 &mut self,
459 buffer: &Entity<Buffer>,
460 project: &Entity<Project>,
461 cx: &mut Context<Self>,
462 ) -> BufferSnapshot {
463 let zeta_project = self.get_or_init_zeta_project(project, cx);
464 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
465
466 let new_snapshot = buffer.read(cx).snapshot();
467 if new_snapshot.version != registered_buffer.snapshot.version {
468 let old_snapshot =
469 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
470 Self::push_event(
471 zeta_project,
472 Event::BufferChange {
473 old_snapshot,
474 new_snapshot: new_snapshot.clone(),
475 timestamp: Instant::now(),
476 },
477 );
478 }
479
480 new_snapshot
481 }
482
483 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
484 let events = &mut zeta_project.events;
485
486 if let Some(Event::BufferChange {
487 new_snapshot: last_new_snapshot,
488 timestamp: last_timestamp,
489 ..
490 }) = events.back_mut()
491 {
492 // Coalesce edits for the same buffer when they happen one after the other.
493 let Event::BufferChange {
494 old_snapshot,
495 new_snapshot,
496 timestamp,
497 } = &event;
498
499 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
500 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
501 && old_snapshot.version == last_new_snapshot.version
502 {
503 *last_new_snapshot = new_snapshot.clone();
504 *last_timestamp = *timestamp;
505 return;
506 }
507 }
508
509 if events.len() >= MAX_EVENT_COUNT {
510 // These are halved instead of popping to improve prompt caching.
511 events.drain(..MAX_EVENT_COUNT / 2);
512 }
513
514 events.push_back(event);
515 }
516
517 fn current_prediction_for_buffer(
518 &self,
519 buffer: &Entity<Buffer>,
520 project: &Entity<Project>,
521 cx: &App,
522 ) -> Option<BufferEditPrediction<'_>> {
523 let project_state = self.projects.get(&project.entity_id())?;
524
525 let CurrentEditPrediction {
526 requested_by_buffer_id,
527 prediction,
528 } = project_state.current_prediction.as_ref()?;
529
530 if prediction.targets_buffer(buffer.read(cx), cx) {
531 Some(BufferEditPrediction::Local { prediction })
532 } else if *requested_by_buffer_id == buffer.entity_id() {
533 Some(BufferEditPrediction::Jump { prediction })
534 } else {
535 None
536 }
537 }
538
539 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
540 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
541 return;
542 };
543
544 let Some(prediction) = project_state.current_prediction.take() else {
545 return;
546 };
547 let request_id = prediction.prediction.id.into();
548
549 let client = self.client.clone();
550 let llm_token = self.llm_token.clone();
551 let app_version = AppVersion::global(cx);
552 cx.spawn(async move |this, cx| {
553 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
554 http_client::Url::parse(&predict_edits_url)?
555 } else {
556 client
557 .http_client()
558 .build_zed_llm_url("/predict_edits/accept", &[])?
559 };
560
561 let response = cx
562 .background_spawn(Self::send_api_request::<()>(
563 move |builder| {
564 let req = builder.uri(url.as_ref()).body(
565 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
566 );
567 Ok(req?)
568 },
569 client,
570 llm_token,
571 app_version,
572 ))
573 .await;
574
575 Self::handle_api_response(&this, response, cx)?;
576 anyhow::Ok(())
577 })
578 .detach_and_log_err(cx);
579 }
580
581 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
582 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
583 project_state.current_prediction.take();
584 };
585 }
586
587 pub fn refresh_prediction(
588 &mut self,
589 project: &Entity<Project>,
590 buffer: &Entity<Buffer>,
591 position: language::Anchor,
592 cx: &mut Context<Self>,
593 ) -> Task<Result<()>> {
594 let request_task = self.request_prediction(project, buffer, position, cx);
595 let buffer = buffer.clone();
596 let project = project.clone();
597
598 cx.spawn(async move |this, cx| {
599 if let Some(prediction) = request_task.await? {
600 this.update(cx, |this, cx| {
601 let project_state = this
602 .projects
603 .get_mut(&project.entity_id())
604 .context("Project not found")?;
605
606 let new_prediction = CurrentEditPrediction {
607 requested_by_buffer_id: buffer.entity_id(),
608 prediction: prediction,
609 };
610
611 if project_state
612 .current_prediction
613 .as_ref()
614 .is_none_or(|old_prediction| {
615 new_prediction.should_replace_prediction(&old_prediction, cx)
616 })
617 {
618 project_state.current_prediction = Some(new_prediction);
619 }
620 anyhow::Ok(())
621 })??;
622 }
623 Ok(())
624 })
625 }
626
627 fn request_prediction(
628 &mut self,
629 project: &Entity<Project>,
630 buffer: &Entity<Buffer>,
631 position: language::Anchor,
632 cx: &mut Context<Self>,
633 ) -> Task<Result<Option<EditPrediction>>> {
634 let project_state = self.projects.get(&project.entity_id());
635
636 let index_state = project_state.map(|state| {
637 state
638 .syntax_index
639 .read_with(cx, |index, _cx| index.state().clone())
640 });
641 let options = self.options.clone();
642 let snapshot = buffer.read(cx).snapshot();
643 let Some(excerpt_path) = snapshot
644 .file()
645 .map(|path| -> Arc<Path> { path.full_path(cx).into() })
646 else {
647 return Task::ready(Err(anyhow!("No file path for excerpt")));
648 };
649 let client = self.client.clone();
650 let llm_token = self.llm_token.clone();
651 let app_version = AppVersion::global(cx);
652 let worktree_snapshots = project
653 .read(cx)
654 .worktrees(cx)
655 .map(|worktree| worktree.read(cx).snapshot())
656 .collect::<Vec<_>>();
657 let debug_tx = self.debug_tx.clone();
658
659 let events = project_state
660 .map(|state| {
661 state
662 .events
663 .iter()
664 .filter_map(|event| event.to_request_event(cx))
665 .collect::<Vec<_>>()
666 })
667 .unwrap_or_default();
668
669 let diagnostics = snapshot.diagnostic_sets().clone();
670
671 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
672 let mut path = f.worktree.read(cx).absolutize(&f.path);
673 if path.pop() { Some(path) } else { None }
674 });
675
676 // TODO data collection
677 let can_collect_data = cx.is_staff();
678
679 let mut included_files = project_state
680 .and_then(|project_state| project_state.context.as_ref())
681 .unwrap_or(&HashMap::default())
682 .iter()
683 .filter_map(|(buffer, ranges)| {
684 let buffer = buffer.read(cx);
685 Some((
686 buffer.snapshot(),
687 buffer.file()?.full_path(cx).into(),
688 ranges.clone(),
689 ))
690 })
691 .collect::<Vec<_>>();
692
693 let request_task = cx.background_spawn({
694 let snapshot = snapshot.clone();
695 let buffer = buffer.clone();
696 async move {
697 let index_state = if let Some(index_state) = index_state {
698 Some(index_state.lock_owned().await)
699 } else {
700 None
701 };
702
703 let cursor_offset = position.to_offset(&snapshot);
704 let cursor_point = cursor_offset.to_point(&snapshot);
705
706 let before_retrieval = chrono::Utc::now();
707
708 let (diagnostic_groups, diagnostic_groups_truncated) =
709 Self::gather_nearby_diagnostics(
710 cursor_offset,
711 &diagnostics,
712 &snapshot,
713 options.max_diagnostic_bytes,
714 );
715
716 let request = match options.context {
717 ContextMode::Llm(context_options) => {
718 let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
719 cursor_point,
720 &snapshot,
721 &context_options.excerpt,
722 index_state.as_deref(),
723 ) else {
724 return Ok((None, None));
725 };
726
727 let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
728 ..snapshot.anchor_before(excerpt.range.end);
729
730 if let Some(buffer_ix) = included_files
731 .iter()
732 .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
733 {
734 let (buffer, _, ranges) = &mut included_files[buffer_ix];
735 let range_ix = ranges
736 .binary_search_by(|probe| {
737 probe
738 .start
739 .cmp(&excerpt_anchor_range.start, buffer)
740 .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
741 })
742 .unwrap_or_else(|ix| ix);
743
744 ranges.insert(range_ix, excerpt_anchor_range);
745 let last_ix = included_files.len() - 1;
746 included_files.swap(buffer_ix, last_ix);
747 } else {
748 included_files.push((
749 snapshot,
750 excerpt_path.clone(),
751 vec![excerpt_anchor_range],
752 ));
753 }
754
755 let included_files = included_files
756 .into_iter()
757 .map(|(buffer, path, ranges)| {
758 let excerpts = merge_excerpts(
759 &buffer,
760 ranges.iter().map(|range| {
761 let point_range = range.to_point(&buffer);
762 Line(point_range.start.row)..Line(point_range.end.row)
763 }),
764 );
765 predict_edits_v3::IncludedFile {
766 path,
767 max_row: Line(buffer.max_point().row),
768 excerpts,
769 }
770 })
771 .collect::<Vec<_>>();
772
773 predict_edits_v3::PredictEditsRequest {
774 excerpt_path,
775 excerpt: String::new(),
776 excerpt_line_range: Line(0)..Line(0),
777 excerpt_range: 0..0,
778 cursor_point: predict_edits_v3::Point {
779 line: predict_edits_v3::Line(cursor_point.row),
780 column: cursor_point.column,
781 },
782 included_files,
783 referenced_declarations: vec![],
784 events,
785 can_collect_data,
786 diagnostic_groups,
787 diagnostic_groups_truncated,
788 debug_info: debug_tx.is_some(),
789 prompt_max_bytes: Some(options.max_prompt_bytes),
790 prompt_format: options.prompt_format,
791 // TODO [zeta2]
792 signatures: vec![],
793 excerpt_parent: None,
794 git_info: None,
795 }
796 }
797 ContextMode::Syntax(context_options) => {
798 let Some(context) = EditPredictionContext::gather_context(
799 cursor_point,
800 &snapshot,
801 parent_abs_path.as_deref(),
802 &context_options,
803 index_state.as_deref(),
804 ) else {
805 return Ok((None, None));
806 };
807
808 make_syntax_context_cloud_request(
809 excerpt_path,
810 context,
811 events,
812 can_collect_data,
813 diagnostic_groups,
814 diagnostic_groups_truncated,
815 None,
816 debug_tx.is_some(),
817 &worktree_snapshots,
818 index_state.as_deref(),
819 Some(options.max_prompt_bytes),
820 options.prompt_format,
821 )
822 }
823 };
824
825 let retrieval_time = chrono::Utc::now() - before_retrieval;
826
827 let debug_response_tx = if let Some(debug_tx) = &debug_tx {
828 let (response_tx, response_rx) = oneshot::channel();
829
830 let local_prompt = build_prompt(&request)
831 .map(|(prompt, _)| prompt)
832 .map_err(|err| err.to_string());
833
834 debug_tx
835 .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
836 request: request.clone(),
837 retrieval_time,
838 buffer: buffer.downgrade(),
839 local_prompt,
840 position,
841 response_rx,
842 }))
843 .ok();
844 Some(response_tx)
845 } else {
846 None
847 };
848
849 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
850 if let Some(debug_response_tx) = debug_response_tx {
851 debug_response_tx
852 .send(Err("Request skipped".to_string()))
853 .ok();
854 }
855 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
856 }
857
858 let response =
859 Self::send_prediction_request(client, llm_token, app_version, request).await;
860
861 if let Some(debug_response_tx) = debug_response_tx {
862 debug_response_tx
863 .send(
864 response
865 .as_ref()
866 .map_err(|err| err.to_string())
867 .map(|response| response.0.clone()),
868 )
869 .ok();
870 }
871
872 response.map(|(res, usage)| (Some(res), usage))
873 }
874 });
875
876 let buffer = buffer.clone();
877
878 cx.spawn({
879 let project = project.clone();
880 async move |this, cx| {
881 let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
882 else {
883 return Ok(None);
884 };
885
886 // TODO telemetry: duration, etc
887 Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
888 }
889 })
890 }
891
892 async fn send_prediction_request(
893 client: Arc<Client>,
894 llm_token: LlmApiToken,
895 app_version: SemanticVersion,
896 request: predict_edits_v3::PredictEditsRequest,
897 ) -> Result<(
898 predict_edits_v3::PredictEditsResponse,
899 Option<EditPredictionUsage>,
900 )> {
901 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
902 http_client::Url::parse(&predict_edits_url)?
903 } else {
904 client
905 .http_client()
906 .build_zed_llm_url("/predict_edits/v3", &[])?
907 };
908
909 Self::send_api_request(
910 |builder| {
911 let req = builder
912 .uri(url.as_ref())
913 .body(serde_json::to_string(&request)?.into());
914 Ok(req?)
915 },
916 client,
917 llm_token,
918 app_version,
919 )
920 .await
921 }
922
923 fn handle_api_response<T>(
924 this: &WeakEntity<Self>,
925 response: Result<(T, Option<EditPredictionUsage>)>,
926 cx: &mut gpui::AsyncApp,
927 ) -> Result<T> {
928 match response {
929 Ok((data, usage)) => {
930 if let Some(usage) = usage {
931 this.update(cx, |this, cx| {
932 this.user_store.update(cx, |user_store, cx| {
933 user_store.update_edit_prediction_usage(usage, cx);
934 });
935 })
936 .ok();
937 }
938 Ok(data)
939 }
940 Err(err) => {
941 if err.is::<ZedUpdateRequiredError>() {
942 cx.update(|cx| {
943 this.update(cx, |this, _cx| {
944 this.update_required = true;
945 })
946 .ok();
947
948 let error_message: SharedString = err.to_string().into();
949 show_app_notification(
950 NotificationId::unique::<ZedUpdateRequiredError>(),
951 cx,
952 move |cx| {
953 cx.new(|cx| {
954 ErrorMessagePrompt::new(error_message.clone(), cx)
955 .with_link_button("Update Zed", "https://zed.dev/releases")
956 })
957 },
958 );
959 })
960 .ok();
961 }
962 Err(err)
963 }
964 }
965 }
966
    /// POSTs a request produced by `build` and decodes the JSON response body as `Res`.
    ///
    /// `build` receives a builder already carrying the POST method, JSON content type,
    /// bearer token, and Zed version header. It is called again if the request is retried.
    ///
    /// Returns the decoded body together with usage parsed from the response headers
    /// (`None` if that parsing fails).
    ///
    /// # Errors
    /// - `ZedUpdateRequiredError` if the server's minimum required version exceeds `app_version`.
    /// - An error carrying the status and body for an unsuccessful response (after at most one
    ///   retry for an expired token).
    /// - Any error from token acquisition/refresh, `build`, sending, reading, or decoding.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        // Allows a single retry, and only after refreshing an expired token.
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Checked before the status, so an outdated client gets `ZedUpdateRequiredError`
            // whether or not the request itself succeeded. Unparseable headers are ignored.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // The server flagged the token as expired: refresh it and go around once more.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1030
    /// Gap between `refresh_context_if_needed` calls after which the next call refreshes
    /// context immediately instead of debouncing.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Delay before refreshing context while the user is actively editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1033
1034 // Refresh the related excerpts when the user just beguns editing after
1035 // an idle period, and after they pause editing.
1036 fn refresh_context_if_needed(
1037 &mut self,
1038 project: &Entity<Project>,
1039 buffer: &Entity<language::Buffer>,
1040 cursor_position: language::Anchor,
1041 cx: &mut Context<Self>,
1042 ) {
1043 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1044 return;
1045 }
1046
1047 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1048 return;
1049 };
1050
1051 let now = Instant::now();
1052 let was_idle = zeta_project
1053 .refresh_context_timestamp
1054 .map_or(true, |timestamp| {
1055 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1056 });
1057 zeta_project.refresh_context_timestamp = Some(now);
1058 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1059 let buffer = buffer.clone();
1060 let project = project.clone();
1061 async move |this, cx| {
1062 if was_idle {
1063 log::debug!("refetching edit prediction context after idle");
1064 } else {
1065 cx.background_executor()
1066 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1067 .await;
1068 log::debug!("refetching edit prediction context after pause");
1069 }
1070 this.update(cx, |this, cx| {
1071 this.refresh_context(project, buffer, cursor_position, cx);
1072 })
1073 .ok()
1074 }
1075 }));
1076 }
1077
1078 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1079 // and avoid spawning more than one concurrent task.
1080 fn refresh_context(
1081 &mut self,
1082 project: Entity<Project>,
1083 buffer: Entity<language::Buffer>,
1084 cursor_position: language::Anchor,
1085 cx: &mut Context<Self>,
1086 ) {
1087 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1088 return;
1089 };
1090
1091 let debug_tx = self.debug_tx.clone();
1092
1093 zeta_project
1094 .refresh_context_task
1095 .get_or_insert(cx.spawn(async move |this, cx| {
1096 let related_excerpts = this
1097 .update(cx, |this, cx| {
1098 let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1099 return Task::ready(anyhow::Ok(HashMap::default()));
1100 };
1101
1102 let ContextMode::Llm(options) = &this.options().context else {
1103 return Task::ready(anyhow::Ok(HashMap::default()));
1104 };
1105
1106 let mut edit_history_unified_diff = String::new();
1107
1108 for event in zeta_project.events.iter() {
1109 if let Some(event) = event.to_request_event(cx) {
1110 writeln!(&mut edit_history_unified_diff, "{event}").ok();
1111 }
1112 }
1113
1114 find_related_excerpts(
1115 buffer.clone(),
1116 cursor_position,
1117 &project,
1118 edit_history_unified_diff,
1119 options,
1120 debug_tx,
1121 cx,
1122 )
1123 })
1124 .ok()?
1125 .await
1126 .log_err()
1127 .unwrap_or_default();
1128 this.update(cx, |this, _cx| {
1129 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1130 return;
1131 };
1132 zeta_project.context = Some(related_excerpts);
1133 zeta_project.refresh_context_task.take();
1134 if let Some(debug_tx) = &this.debug_tx {
1135 debug_tx
1136 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1137 ZetaContextRetrievalDebugInfo {
1138 project,
1139 timestamp: Instant::now(),
1140 },
1141 ))
1142 .ok();
1143 }
1144 })
1145 .ok()
1146 }));
1147 }
1148
1149 fn gather_nearby_diagnostics(
1150 cursor_offset: usize,
1151 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1152 snapshot: &BufferSnapshot,
1153 max_diagnostics_bytes: usize,
1154 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1155 // TODO: Could make this more efficient
1156 let mut diagnostic_groups = Vec::new();
1157 for (language_server_id, diagnostics) in diagnostic_sets {
1158 let mut groups = Vec::new();
1159 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1160 diagnostic_groups.extend(
1161 groups
1162 .into_iter()
1163 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1164 );
1165 }
1166
1167 // sort by proximity to cursor
1168 diagnostic_groups.sort_by_key(|group| {
1169 let range = &group.entries[group.primary_ix].range;
1170 if range.start >= cursor_offset {
1171 range.start - cursor_offset
1172 } else if cursor_offset >= range.end {
1173 cursor_offset - range.end
1174 } else {
1175 (cursor_offset - range.start).min(range.end - cursor_offset)
1176 }
1177 });
1178
1179 let mut results = Vec::new();
1180 let mut diagnostic_groups_truncated = false;
1181 let mut diagnostics_byte_count = 0;
1182 for group in diagnostic_groups {
1183 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1184 diagnostics_byte_count += raw_value.get().len();
1185 if diagnostics_byte_count > max_diagnostics_bytes {
1186 diagnostic_groups_truncated = true;
1187 break;
1188 }
1189 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1190 }
1191
1192 (results, diagnostic_groups_truncated)
1193 }
1194
1195 // TODO: Dedupe with similar code in request_prediction?
1196 pub fn cloud_request_for_zeta_cli(
1197 &mut self,
1198 project: &Entity<Project>,
1199 buffer: &Entity<Buffer>,
1200 position: language::Anchor,
1201 cx: &mut Context<Self>,
1202 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1203 let project_state = self.projects.get(&project.entity_id());
1204
1205 let index_state = project_state.map(|state| {
1206 state
1207 .syntax_index
1208 .read_with(cx, |index, _cx| index.state().clone())
1209 });
1210 let options = self.options.clone();
1211 let snapshot = buffer.read(cx).snapshot();
1212 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1213 return Task::ready(Err(anyhow!("No file path for excerpt")));
1214 };
1215 let worktree_snapshots = project
1216 .read(cx)
1217 .worktrees(cx)
1218 .map(|worktree| worktree.read(cx).snapshot())
1219 .collect::<Vec<_>>();
1220
1221 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1222 let mut path = f.worktree.read(cx).absolutize(&f.path);
1223 if path.pop() { Some(path) } else { None }
1224 });
1225
1226 cx.background_spawn(async move {
1227 let index_state = if let Some(index_state) = index_state {
1228 Some(index_state.lock_owned().await)
1229 } else {
1230 None
1231 };
1232
1233 let cursor_point = position.to_point(&snapshot);
1234
1235 let debug_info = true;
1236 EditPredictionContext::gather_context(
1237 cursor_point,
1238 &snapshot,
1239 parent_abs_path.as_deref(),
1240 match &options.context {
1241 ContextMode::Llm(_) => {
1242 // TODO
1243 panic!("Llm mode not supported in zeta cli yet");
1244 }
1245 ContextMode::Syntax(edit_prediction_context_options) => {
1246 edit_prediction_context_options
1247 }
1248 },
1249 index_state.as_deref(),
1250 )
1251 .context("Failed to select excerpt")
1252 .map(|context| {
1253 make_syntax_context_cloud_request(
1254 excerpt_path.into(),
1255 context,
1256 // TODO pass everything
1257 Vec::new(),
1258 false,
1259 Vec::new(),
1260 false,
1261 None,
1262 debug_info,
1263 &worktree_snapshots,
1264 index_state.as_deref(),
1265 Some(options.max_prompt_bytes),
1266 options.prompt_format,
1267 )
1268 })
1269 })
1270 }
1271
1272 pub fn wait_for_initial_indexing(
1273 &mut self,
1274 project: &Entity<Project>,
1275 cx: &mut App,
1276 ) -> Task<Result<()>> {
1277 let zeta_project = self.get_or_init_zeta_project(project, cx);
1278 zeta_project
1279 .syntax_index
1280 .read(cx)
1281 .wait_for_initial_file_indexing(cx)
1282 }
1283}
1284
1285#[derive(Error, Debug)]
1286#[error(
1287 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1288)]
1289pub struct ZedUpdateRequiredError {
1290 minimum_version: SemanticVersion,
1291}
1292
1293fn make_syntax_context_cloud_request(
1294 excerpt_path: Arc<Path>,
1295 context: EditPredictionContext,
1296 events: Vec<predict_edits_v3::Event>,
1297 can_collect_data: bool,
1298 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1299 diagnostic_groups_truncated: bool,
1300 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1301 debug_info: bool,
1302 worktrees: &Vec<worktree::Snapshot>,
1303 index_state: Option<&SyntaxIndexState>,
1304 prompt_max_bytes: Option<usize>,
1305 prompt_format: PromptFormat,
1306) -> predict_edits_v3::PredictEditsRequest {
1307 let mut signatures = Vec::new();
1308 let mut declaration_to_signature_index = HashMap::default();
1309 let mut referenced_declarations = Vec::new();
1310
1311 for snippet in context.declarations {
1312 let project_entry_id = snippet.declaration.project_entry_id();
1313 let Some(path) = worktrees.iter().find_map(|worktree| {
1314 worktree.entry_for_id(project_entry_id).map(|entry| {
1315 let mut full_path = RelPathBuf::new();
1316 full_path.push(worktree.root_name());
1317 full_path.push(&entry.path);
1318 full_path
1319 })
1320 }) else {
1321 continue;
1322 };
1323
1324 let parent_index = index_state.and_then(|index_state| {
1325 snippet.declaration.parent().and_then(|parent| {
1326 add_signature(
1327 parent,
1328 &mut declaration_to_signature_index,
1329 &mut signatures,
1330 index_state,
1331 )
1332 })
1333 });
1334
1335 let (text, text_is_truncated) = snippet.declaration.item_text();
1336 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1337 path: path.as_std_path().into(),
1338 text: text.into(),
1339 range: snippet.declaration.item_line_range(),
1340 text_is_truncated,
1341 signature_range: snippet.declaration.signature_range_in_item_text(),
1342 parent_index,
1343 signature_score: snippet.score(DeclarationStyle::Signature),
1344 declaration_score: snippet.score(DeclarationStyle::Declaration),
1345 score_components: snippet.components,
1346 });
1347 }
1348
1349 let excerpt_parent = index_state.and_then(|index_state| {
1350 context
1351 .excerpt
1352 .parent_declarations
1353 .last()
1354 .and_then(|(parent, _)| {
1355 add_signature(
1356 *parent,
1357 &mut declaration_to_signature_index,
1358 &mut signatures,
1359 index_state,
1360 )
1361 })
1362 });
1363
1364 predict_edits_v3::PredictEditsRequest {
1365 excerpt_path,
1366 excerpt: context.excerpt_text.body,
1367 excerpt_line_range: context.excerpt.line_range,
1368 excerpt_range: context.excerpt.range,
1369 cursor_point: predict_edits_v3::Point {
1370 line: predict_edits_v3::Line(context.cursor_point.row),
1371 column: context.cursor_point.column,
1372 },
1373 referenced_declarations,
1374 included_files: vec![],
1375 signatures,
1376 excerpt_parent,
1377 events,
1378 can_collect_data,
1379 diagnostic_groups,
1380 diagnostic_groups_truncated,
1381 git_info,
1382 debug_info,
1383 prompt_max_bytes,
1384 prompt_format,
1385 }
1386}
1387
1388fn add_signature(
1389 declaration_id: DeclarationId,
1390 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1391 signatures: &mut Vec<Signature>,
1392 index: &SyntaxIndexState,
1393) -> Option<usize> {
1394 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1395 return Some(*signature_index);
1396 }
1397 let Some(parent_declaration) = index.declaration(declaration_id) else {
1398 log::error!("bug: missing parent declaration");
1399 return None;
1400 };
1401 let parent_index = parent_declaration.parent().and_then(|parent| {
1402 add_signature(parent, declaration_to_signature_index, signatures, index)
1403 });
1404 let (text, text_is_truncated) = parent_declaration.signature_text();
1405 let signature_index = signatures.len();
1406 signatures.push(Signature {
1407 text: text.into(),
1408 text_is_truncated,
1409 parent_index,
1410 range: parent_declaration.signature_line_range(),
1411 });
1412 declaration_to_signature_index.insert(declaration_id, signature_index);
1413 Some(signature_index)
1414}
1415
// Tests for `Zeta`. `init_test` installs a fake HTTP client that forwards every
// `/predict_edits/v3` request to the test through an mpsc channel. Each request is paired
// with a oneshot sender the test uses to supply the server's response.
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Checks what `current_prediction_for_buffer` reports after each refresh.
    ///
    /// - An edit to the requesting buffer is `Local`.
    /// - An edit to a different file is a `Jump` to that file.
    /// - Once the other file is open, it sees the prediction as `Local`.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Wait for the fake server to receive the request, then answer it with a
        // whole-file replacement of 1.txt.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // The edit range reuses 1.txt's line count; both fixture files have three lines.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // After opening 2.txt, the same prediction is reported as local to that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Checks that `request_prediction` sends the excerpt path and cursor point.
    /// Also checks that a whole-file replacement response is reduced to a single
    /// insertion at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // "How" -> "How are you?" is expected as a single insertion after "How".
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an edit made to a registered buffer is sent as a `BufferChange`
    /// event carrying a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that a diagnostic published through the LSP store is included in the
    /// request as a serialized diagnostic group.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        // Line 1 is "Bye" (3 chars), so the LSP end column 5 lies past the end of the
        // line. The expected resolved range below is 8..10; presumably the end is
        // clipped to the end of the buffer.
        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // The task and the response sender are bound to `_`-prefixed names (not `_`),
        // so they stay alive until the end of the test. The request is inspected but
        // never answered.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Sets up settings, language and project globals, a fake HTTP backend, a client
    /// with test credentials, and the global `Zeta` entity.
    ///
    /// Returns the `Zeta` entity plus a receiver yielding each
    /// `/predict_edits/v3` request together with the oneshot sender for its response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            // Fake backend: returns a fixed LLM token, and forwards prediction requests
            // to the test. Any other path panics.
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block until it
                                // responds. If the test drops the sender, `?` turns
                                // the cancellation into an error for this request.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}