1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::ops::Range;
32use std::path::Path;
33use std::str::FromStr as _;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use thiserror::Error;
37use util::ResultExt as _;
38use util::rel_path::RelPathBuf;
39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
40
41mod merge_excerpts;
42mod prediction;
43mod provider;
44mod related_excerpts;
45
46use crate::merge_excerpts::merge_excerpts;
47use crate::prediction::EditPrediction;
48pub use crate::related_excerpts::LlmContextOptions;
49use crate::related_excerpts::find_related_excerpts;
50pub use provider::ZetaEditPredictionProvider;
51
/// Consecutive changes to the same buffer recorded within this interval of each other are
/// coalesced into a single [`Event::BufferChange`] by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track per project.
///
/// When the limit is reached, `Zeta::push_event` drops the oldest half of the history at once
/// rather than popping one event at a time.
const MAX_EVENT_COUNT: usize = 16;
56
/// Default options for selecting the excerpt around the cursor.
///
/// NOTE(review): `target_before_cursor_over_total_bytes` presumably is the fraction of the
/// excerpt meant to precede the cursor — confirm against `edit_prediction_context`.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context-gathering mode: LLM-based related-excerpt retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);

/// Default options for [`ContextMode::Llm`].
pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for [`ContextMode::Syntax`] (context gathered from the syntax index).
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Options a new [`Zeta`] starts with (see `Zeta::new`); replace them with `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
86
/// Feature flag gating the zeta2 edit prediction engine.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // NOTE(review): returning `false` presumably opts staff out of automatic enablement of
    // this flag — confirm against `feature_flags::FeatureFlag`.
    fn enabled_for_staff() -> bool {
        false
    }
}
96
/// App-global slot holding the shared [`Zeta`] entity; populated by `Zeta::global` and read by
/// `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
101
/// The zeta2 edit prediction engine.
///
/// Tracks per-project edit history and related context, requests predictions from the Zed LLM
/// endpoint, and holds the current prediction for each project.
pub struct Zeta {
    /// Client used to reach the Zed LLM endpoints.
    client: Arc<Client>,
    /// Receives edit-prediction usage reported in response headers.
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed when `RefreshLlmTokenListener` emits and on expired-token replies.
    llm_token: LlmApiToken,
    /// Held so the token-refresh subscription stays active while this entity lives.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options used when building requests.
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is below the minimum required one.
    update_required: bool,
    /// When installed via `debug_info`, every prediction request is mirrored to this channel.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
112
/// Tunable options controlling how prediction requests are built.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How related context is gathered for a request.
    pub context: ContextMode,
    /// Prompt size limit, sent to the server as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
121
122#[derive(Debug, Clone, PartialEq)]
123pub enum ContextMode {
124 Llm(LlmContextOptions),
125 Syntax(EditPredictionContextOptions),
126}
127
128impl ContextMode {
129 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
130 match self {
131 ContextMode::Llm(options) => &options.excerpt,
132 ContextMode::Syntax(options) => &options.excerpt,
133 }
134 }
135}
136
/// Debug record sent for each prediction request once a channel is installed with
/// `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// The request as it is sent to the server.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics/context and building the request.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally with `build_prompt`, or the rendering error message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response, or an error message, once the request completes.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Alias for the debug information type of the v3 prediction API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
147
/// Per-project prediction state.
struct ZetaProject {
    /// Syntax index for the project, used for excerpt selection and syntax-mode context.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change history (bounded by `MAX_EVENT_COUNT`).
    events: VecDeque<Event>,
    /// Buffers tracked for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts per buffer, as last retrieved by `refresh_context`.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context refresh task, if any; the task clears this slot when it finishes.
    refresh_context_task: Option<Task<Option<()>>>,
    /// Pending (debounced) task that will call `refresh_context`.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When `refresh_context_if_needed` last ran; used to detect idle periods.
    refresh_context_timestamp: Option<Instant>,
}
158
159#[derive(Debug, Clone)]
160struct CurrentEditPrediction {
161 pub requested_by_buffer_id: EntityId,
162 pub prediction: EditPrediction,
163}
164
165impl CurrentEditPrediction {
166 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
167 let Some(new_edits) = self
168 .prediction
169 .interpolate(&self.prediction.buffer.read(cx))
170 else {
171 return false;
172 };
173
174 if self.prediction.buffer != old_prediction.prediction.buffer {
175 return true;
176 }
177
178 let Some(old_edits) = old_prediction
179 .prediction
180 .interpolate(&old_prediction.prediction.buffer.read(cx))
181 else {
182 return true;
183 };
184
185 // This reduces the occurrence of UI thrash from replacing edits
186 //
187 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
188 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
189 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
190 && old_edits.len() == 1
191 && new_edits.len() == 1
192 {
193 let (old_range, old_text) = &old_edits[0];
194 let (new_range, new_text) = &new_edits[0];
195 new_range == old_range && new_text.starts_with(old_text)
196 } else {
197 true
198 }
199 }
200}
201
/// The project's current prediction, as seen from one particular buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer but was requested from this one
    /// (presumably surfaced as a jump to the target — confirm with the provider).
    Jump { prediction: &'a EditPrediction },
}
208
/// A buffer whose edits are being recorded into the project's event history.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next change is recorded against it.
    snapshot: BufferSnapshot,
    /// Edit and release subscriptions for the buffer, held for as long as this entry exists.
    _subscriptions: [gpui::Subscription; 2],
}
213
214#[derive(Clone)]
215pub enum Event {
216 BufferChange {
217 old_snapshot: BufferSnapshot,
218 new_snapshot: BufferSnapshot,
219 timestamp: Instant,
220 },
221}
222
223impl Event {
224 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
225 match self {
226 Event::BufferChange {
227 old_snapshot,
228 new_snapshot,
229 ..
230 } => {
231 let path = new_snapshot.file().map(|f| f.full_path(cx));
232
233 let old_path = old_snapshot.file().and_then(|f| {
234 let old_path = f.full_path(cx);
235 if Some(&old_path) != path.as_ref() {
236 Some(old_path)
237 } else {
238 None
239 }
240 });
241
242 // TODO [zeta2] move to bg?
243 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
244
245 if path == old_path && diff.is_empty() {
246 None
247 } else {
248 Some(predict_edits_v3::Event::BufferChange {
249 old_path,
250 path,
251 diff,
252 //todo: Actually detect if this edit was predicted or not
253 predicted: false,
254 })
255 }
256 }
257 }
258 }
259}
260
261impl Zeta {
262 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
263 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
264 }
265
266 pub fn global(
267 client: &Arc<Client>,
268 user_store: &Entity<UserStore>,
269 cx: &mut App,
270 ) -> Entity<Self> {
271 cx.try_global::<ZetaGlobal>()
272 .map(|global| global.0.clone())
273 .unwrap_or_else(|| {
274 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
275 cx.set_global(ZetaGlobal(zeta.clone()));
276 zeta
277 })
278 }
279
280 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
281 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
282
283 Self {
284 projects: HashMap::default(),
285 client,
286 user_store,
287 options: DEFAULT_OPTIONS,
288 llm_token: LlmApiToken::default(),
289 _llm_token_subscription: cx.subscribe(
290 &refresh_llm_token_listener,
291 |this, _listener, _event, cx| {
292 let client = this.client.clone();
293 let llm_token = this.llm_token.clone();
294 cx.spawn(async move |_this, _cx| {
295 llm_token.refresh(&client).await?;
296 anyhow::Ok(())
297 })
298 .detach_and_log_err(cx);
299 },
300 ),
301 update_required: false,
302 debug_tx: None,
303 }
304 }
305
306 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
307 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
308 self.debug_tx = Some(debug_watch_tx);
309 debug_watch_rx
310 }
311
312 pub fn options(&self) -> &ZetaOptions {
313 &self.options
314 }
315
316 pub fn set_options(&mut self, options: ZetaOptions) {
317 self.options = options;
318 }
319
320 pub fn clear_history(&mut self) {
321 for zeta_project in self.projects.values_mut() {
322 zeta_project.events.clear();
323 }
324 }
325
326 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
327 static EMPTY_EVENTS: VecDeque<Event> = VecDeque::new();
328 self.projects
329 .get(&project.entity_id())
330 .map_or(&EMPTY_EVENTS, |project| &project.events)
331 .iter()
332 }
333
334 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
335 self.user_store.read(cx).edit_prediction_usage()
336 }
337
338 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
339 self.get_or_init_zeta_project(project, cx);
340 }
341
342 pub fn register_buffer(
343 &mut self,
344 buffer: &Entity<Buffer>,
345 project: &Entity<Project>,
346 cx: &mut Context<Self>,
347 ) {
348 let zeta_project = self.get_or_init_zeta_project(project, cx);
349 Self::register_buffer_impl(zeta_project, buffer, project, cx);
350 }
351
352 fn get_or_init_zeta_project(
353 &mut self,
354 project: &Entity<Project>,
355 cx: &mut App,
356 ) -> &mut ZetaProject {
357 self.projects
358 .entry(project.entity_id())
359 .or_insert_with(|| ZetaProject {
360 syntax_index: cx.new(|cx| {
361 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
362 }),
363 events: VecDeque::new(),
364 registered_buffers: HashMap::default(),
365 current_prediction: None,
366 context: None,
367 refresh_context_task: None,
368 refresh_context_debounce_task: None,
369 refresh_context_timestamp: None,
370 })
371 }
372
373 fn register_buffer_impl<'a>(
374 zeta_project: &'a mut ZetaProject,
375 buffer: &Entity<Buffer>,
376 project: &Entity<Project>,
377 cx: &mut Context<Self>,
378 ) -> &'a mut RegisteredBuffer {
379 let buffer_id = buffer.entity_id();
380 match zeta_project.registered_buffers.entry(buffer_id) {
381 hash_map::Entry::Occupied(entry) => entry.into_mut(),
382 hash_map::Entry::Vacant(entry) => {
383 let snapshot = buffer.read(cx).snapshot();
384 let project_entity_id = project.entity_id();
385 entry.insert(RegisteredBuffer {
386 snapshot,
387 _subscriptions: [
388 cx.subscribe(buffer, {
389 let project = project.downgrade();
390 move |this, buffer, event, cx| {
391 if let language::BufferEvent::Edited = event
392 && let Some(project) = project.upgrade()
393 {
394 this.report_changes_for_buffer(&buffer, &project, cx);
395 }
396 }
397 }),
398 cx.observe_release(buffer, move |this, _buffer, _cx| {
399 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
400 else {
401 return;
402 };
403 zeta_project.registered_buffers.remove(&buffer_id);
404 }),
405 ],
406 })
407 }
408 }
409 }
410
411 fn report_changes_for_buffer(
412 &mut self,
413 buffer: &Entity<Buffer>,
414 project: &Entity<Project>,
415 cx: &mut Context<Self>,
416 ) -> BufferSnapshot {
417 let zeta_project = self.get_or_init_zeta_project(project, cx);
418 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
419
420 let new_snapshot = buffer.read(cx).snapshot();
421 if new_snapshot.version != registered_buffer.snapshot.version {
422 let old_snapshot =
423 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
424 Self::push_event(
425 zeta_project,
426 Event::BufferChange {
427 old_snapshot,
428 new_snapshot: new_snapshot.clone(),
429 timestamp: Instant::now(),
430 },
431 );
432 }
433
434 new_snapshot
435 }
436
437 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
438 let events = &mut zeta_project.events;
439
440 if let Some(Event::BufferChange {
441 new_snapshot: last_new_snapshot,
442 timestamp: last_timestamp,
443 ..
444 }) = events.back_mut()
445 {
446 // Coalesce edits for the same buffer when they happen one after the other.
447 let Event::BufferChange {
448 old_snapshot,
449 new_snapshot,
450 timestamp,
451 } = &event;
452
453 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
454 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
455 && old_snapshot.version == last_new_snapshot.version
456 {
457 *last_new_snapshot = new_snapshot.clone();
458 *last_timestamp = *timestamp;
459 return;
460 }
461 }
462
463 if events.len() >= MAX_EVENT_COUNT {
464 // These are halved instead of popping to improve prompt caching.
465 events.drain(..MAX_EVENT_COUNT / 2);
466 }
467
468 events.push_back(event);
469 }
470
471 fn current_prediction_for_buffer(
472 &self,
473 buffer: &Entity<Buffer>,
474 project: &Entity<Project>,
475 cx: &App,
476 ) -> Option<BufferEditPrediction<'_>> {
477 let project_state = self.projects.get(&project.entity_id())?;
478
479 let CurrentEditPrediction {
480 requested_by_buffer_id,
481 prediction,
482 } = project_state.current_prediction.as_ref()?;
483
484 if prediction.targets_buffer(buffer.read(cx), cx) {
485 Some(BufferEditPrediction::Local { prediction })
486 } else if *requested_by_buffer_id == buffer.entity_id() {
487 Some(BufferEditPrediction::Jump { prediction })
488 } else {
489 None
490 }
491 }
492
493 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
494 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
495 return;
496 };
497
498 let Some(prediction) = project_state.current_prediction.take() else {
499 return;
500 };
501 let request_id = prediction.prediction.id.into();
502
503 let client = self.client.clone();
504 let llm_token = self.llm_token.clone();
505 let app_version = AppVersion::global(cx);
506 cx.spawn(async move |this, cx| {
507 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
508 http_client::Url::parse(&predict_edits_url)?
509 } else {
510 client
511 .http_client()
512 .build_zed_llm_url("/predict_edits/accept", &[])?
513 };
514
515 let response = cx
516 .background_spawn(Self::send_api_request::<()>(
517 move |builder| {
518 let req = builder.uri(url.as_ref()).body(
519 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
520 );
521 Ok(req?)
522 },
523 client,
524 llm_token,
525 app_version,
526 ))
527 .await;
528
529 Self::handle_api_response(&this, response, cx)?;
530 anyhow::Ok(())
531 })
532 .detach_and_log_err(cx);
533 }
534
535 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
536 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
537 project_state.current_prediction.take();
538 };
539 }
540
541 pub fn refresh_prediction(
542 &mut self,
543 project: &Entity<Project>,
544 buffer: &Entity<Buffer>,
545 position: language::Anchor,
546 cx: &mut Context<Self>,
547 ) -> Task<Result<()>> {
548 let request_task = self.request_prediction(project, buffer, position, cx);
549 let buffer = buffer.clone();
550 let project = project.clone();
551
552 cx.spawn(async move |this, cx| {
553 if let Some(prediction) = request_task.await? {
554 this.update(cx, |this, cx| {
555 let project_state = this
556 .projects
557 .get_mut(&project.entity_id())
558 .context("Project not found")?;
559
560 let new_prediction = CurrentEditPrediction {
561 requested_by_buffer_id: buffer.entity_id(),
562 prediction: prediction,
563 };
564
565 if project_state
566 .current_prediction
567 .as_ref()
568 .is_none_or(|old_prediction| {
569 new_prediction.should_replace_prediction(&old_prediction, cx)
570 })
571 {
572 project_state.current_prediction = Some(new_prediction);
573 }
574 anyhow::Ok(())
575 })??;
576 }
577 Ok(())
578 })
579 }
580
581 fn request_prediction(
582 &mut self,
583 project: &Entity<Project>,
584 buffer: &Entity<Buffer>,
585 position: language::Anchor,
586 cx: &mut Context<Self>,
587 ) -> Task<Result<Option<EditPrediction>>> {
588 let project_state = self.projects.get(&project.entity_id());
589
590 let index_state = project_state.map(|state| {
591 state
592 .syntax_index
593 .read_with(cx, |index, _cx| index.state().clone())
594 });
595 let options = self.options.clone();
596 let snapshot = buffer.read(cx).snapshot();
597 let Some(excerpt_path) = snapshot
598 .file()
599 .map(|path| -> Arc<Path> { path.full_path(cx).into() })
600 else {
601 return Task::ready(Err(anyhow!("No file path for excerpt")));
602 };
603 let client = self.client.clone();
604 let llm_token = self.llm_token.clone();
605 let app_version = AppVersion::global(cx);
606 let worktree_snapshots = project
607 .read(cx)
608 .worktrees(cx)
609 .map(|worktree| worktree.read(cx).snapshot())
610 .collect::<Vec<_>>();
611 let debug_tx = self.debug_tx.clone();
612
613 let events = project_state
614 .map(|state| {
615 state
616 .events
617 .iter()
618 .filter_map(|event| event.to_request_event(cx))
619 .collect::<Vec<_>>()
620 })
621 .unwrap_or_default();
622
623 let diagnostics = snapshot.diagnostic_sets().clone();
624
625 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
626 let mut path = f.worktree.read(cx).absolutize(&f.path);
627 if path.pop() { Some(path) } else { None }
628 });
629
630 // TODO data collection
631 let can_collect_data = cx.is_staff();
632
633 let mut included_files = project_state
634 .and_then(|project_state| project_state.context.as_ref())
635 .unwrap_or(&HashMap::default())
636 .iter()
637 .filter_map(|(buffer, ranges)| {
638 let buffer = buffer.read(cx);
639 Some((
640 buffer.snapshot(),
641 buffer.file()?.full_path(cx).into(),
642 ranges.clone(),
643 ))
644 })
645 .collect::<Vec<_>>();
646
647 let request_task = cx.background_spawn({
648 let snapshot = snapshot.clone();
649 let buffer = buffer.clone();
650 async move {
651 let index_state = if let Some(index_state) = index_state {
652 Some(index_state.lock_owned().await)
653 } else {
654 None
655 };
656
657 let cursor_offset = position.to_offset(&snapshot);
658 let cursor_point = cursor_offset.to_point(&snapshot);
659
660 let before_retrieval = chrono::Utc::now();
661
662 let (diagnostic_groups, diagnostic_groups_truncated) =
663 Self::gather_nearby_diagnostics(
664 cursor_offset,
665 &diagnostics,
666 &snapshot,
667 options.max_diagnostic_bytes,
668 );
669
670 let request = match options.context {
671 ContextMode::Llm(context_options) => {
672 let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
673 cursor_point,
674 &snapshot,
675 &context_options.excerpt,
676 index_state.as_deref(),
677 ) else {
678 return Ok((None, None));
679 };
680
681 let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
682 ..snapshot.anchor_before(excerpt.range.end);
683
684 if let Some(buffer_ix) = included_files
685 .iter()
686 .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
687 {
688 let (buffer, _, ranges) = &mut included_files[buffer_ix];
689 let range_ix = ranges
690 .binary_search_by(|probe| {
691 probe
692 .start
693 .cmp(&excerpt_anchor_range.start, buffer)
694 .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
695 })
696 .unwrap_or_else(|ix| ix);
697
698 ranges.insert(range_ix, excerpt_anchor_range);
699 let last_ix = included_files.len() - 1;
700 included_files.swap(buffer_ix, last_ix);
701 } else {
702 included_files.push((
703 snapshot,
704 excerpt_path.clone(),
705 vec![excerpt_anchor_range],
706 ));
707 }
708
709 let included_files = included_files
710 .into_iter()
711 .map(|(buffer, path, ranges)| {
712 let excerpts = merge_excerpts(
713 &buffer,
714 ranges.iter().map(|range| {
715 let point_range = range.to_point(&buffer);
716 Line(point_range.start.row)..Line(point_range.end.row)
717 }),
718 );
719 predict_edits_v3::IncludedFile {
720 path,
721 max_row: Line(buffer.max_point().row),
722 excerpts,
723 }
724 })
725 .collect::<Vec<_>>();
726
727 predict_edits_v3::PredictEditsRequest {
728 excerpt_path,
729 excerpt: String::new(),
730 excerpt_line_range: Line(0)..Line(0),
731 excerpt_range: 0..0,
732 cursor_point: predict_edits_v3::Point {
733 line: predict_edits_v3::Line(cursor_point.row),
734 column: cursor_point.column,
735 },
736 included_files,
737 referenced_declarations: vec![],
738 events,
739 can_collect_data,
740 diagnostic_groups,
741 diagnostic_groups_truncated,
742 debug_info: debug_tx.is_some(),
743 prompt_max_bytes: Some(options.max_prompt_bytes),
744 prompt_format: options.prompt_format,
745 // TODO [zeta2]
746 signatures: vec![],
747 excerpt_parent: None,
748 git_info: None,
749 }
750 }
751 ContextMode::Syntax(context_options) => {
752 let Some(context) = EditPredictionContext::gather_context(
753 cursor_point,
754 &snapshot,
755 parent_abs_path.as_deref(),
756 &context_options,
757 index_state.as_deref(),
758 ) else {
759 return Ok((None, None));
760 };
761
762 make_syntax_context_cloud_request(
763 excerpt_path,
764 context,
765 events,
766 can_collect_data,
767 diagnostic_groups,
768 diagnostic_groups_truncated,
769 None,
770 debug_tx.is_some(),
771 &worktree_snapshots,
772 index_state.as_deref(),
773 Some(options.max_prompt_bytes),
774 options.prompt_format,
775 )
776 }
777 };
778
779 let retrieval_time = chrono::Utc::now() - before_retrieval;
780
781 let debug_response_tx = if let Some(debug_tx) = &debug_tx {
782 let (response_tx, response_rx) = oneshot::channel();
783
784 if !request.referenced_declarations.is_empty() || !request.signatures.is_empty()
785 {
786 } else {
787 };
788
789 let local_prompt = build_prompt(&request)
790 .map(|(prompt, _)| prompt)
791 .map_err(|err| err.to_string());
792
793 debug_tx
794 .unbounded_send(PredictionDebugInfo {
795 request: request.clone(),
796 retrieval_time,
797 buffer: buffer.downgrade(),
798 local_prompt,
799 position,
800 response_rx,
801 })
802 .ok();
803 Some(response_tx)
804 } else {
805 None
806 };
807
808 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
809 if let Some(debug_response_tx) = debug_response_tx {
810 debug_response_tx
811 .send(Err("Request skipped".to_string()))
812 .ok();
813 }
814 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
815 }
816
817 let response =
818 Self::send_prediction_request(client, llm_token, app_version, request).await;
819
820 if let Some(debug_response_tx) = debug_response_tx {
821 debug_response_tx
822 .send(
823 response
824 .as_ref()
825 .map_err(|err| err.to_string())
826 .map(|response| response.0.clone()),
827 )
828 .ok();
829 }
830
831 response.map(|(res, usage)| (Some(res), usage))
832 }
833 });
834
835 let buffer = buffer.clone();
836
837 cx.spawn({
838 let project = project.clone();
839 async move |this, cx| {
840 let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
841 else {
842 return Ok(None);
843 };
844
845 // TODO telemetry: duration, etc
846 Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
847 }
848 })
849 }
850
851 async fn send_prediction_request(
852 client: Arc<Client>,
853 llm_token: LlmApiToken,
854 app_version: SemanticVersion,
855 request: predict_edits_v3::PredictEditsRequest,
856 ) -> Result<(
857 predict_edits_v3::PredictEditsResponse,
858 Option<EditPredictionUsage>,
859 )> {
860 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
861 http_client::Url::parse(&predict_edits_url)?
862 } else {
863 client
864 .http_client()
865 .build_zed_llm_url("/predict_edits/v3", &[])?
866 };
867
868 Self::send_api_request(
869 |builder| {
870 let req = builder
871 .uri(url.as_ref())
872 .body(serde_json::to_string(&request)?.into());
873 Ok(req?)
874 },
875 client,
876 llm_token,
877 app_version,
878 )
879 .await
880 }
881
882 fn handle_api_response<T>(
883 this: &WeakEntity<Self>,
884 response: Result<(T, Option<EditPredictionUsage>)>,
885 cx: &mut gpui::AsyncApp,
886 ) -> Result<T> {
887 match response {
888 Ok((data, usage)) => {
889 if let Some(usage) = usage {
890 this.update(cx, |this, cx| {
891 this.user_store.update(cx, |user_store, cx| {
892 user_store.update_edit_prediction_usage(usage, cx);
893 });
894 })
895 .ok();
896 }
897 Ok(data)
898 }
899 Err(err) => {
900 if err.is::<ZedUpdateRequiredError>() {
901 cx.update(|cx| {
902 this.update(cx, |this, _cx| {
903 this.update_required = true;
904 })
905 .ok();
906
907 let error_message: SharedString = err.to_string().into();
908 show_app_notification(
909 NotificationId::unique::<ZedUpdateRequiredError>(),
910 cx,
911 move |cx| {
912 cx.new(|cx| {
913 ErrorMessagePrompt::new(error_message.clone(), cx)
914 .with_link_button("Update Zed", "https://zed.dev/releases")
915 })
916 },
917 );
918 })
919 .ok();
920 }
921 Err(err)
922 }
923 }
924 }
925
926 async fn send_api_request<Res>(
927 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
928 client: Arc<Client>,
929 llm_token: LlmApiToken,
930 app_version: SemanticVersion,
931 ) -> Result<(Res, Option<EditPredictionUsage>)>
932 where
933 Res: DeserializeOwned,
934 {
935 let http_client = client.http_client();
936 let mut token = llm_token.acquire(&client).await?;
937 let mut did_retry = false;
938
939 loop {
940 let request_builder = http_client::Request::builder().method(Method::POST);
941
942 let request = build(
943 request_builder
944 .header("Content-Type", "application/json")
945 .header("Authorization", format!("Bearer {}", token))
946 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
947 )?;
948
949 let mut response = http_client.send(request).await?;
950
951 if let Some(minimum_required_version) = response
952 .headers()
953 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
954 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
955 {
956 anyhow::ensure!(
957 app_version >= minimum_required_version,
958 ZedUpdateRequiredError {
959 minimum_version: minimum_required_version
960 }
961 );
962 }
963
964 if response.status().is_success() {
965 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
966
967 let mut body = Vec::new();
968 response.body_mut().read_to_end(&mut body).await?;
969 return Ok((serde_json::from_slice(&body)?, usage));
970 } else if !did_retry
971 && response
972 .headers()
973 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
974 .is_some()
975 {
976 did_retry = true;
977 token = llm_token.refresh(&client).await?;
978 } else {
979 let mut body = String::new();
980 response.body_mut().read_to_string(&mut body).await?;
981 anyhow::bail!(
982 "Request failed with status: {:?}\nBody: {}",
983 response.status(),
984 body
985 );
986 }
987 }
988 }
989
990 pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
991 pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
992
993 // Refresh the related excerpts when the user just beguns editing after
994 // an idle period, and after they pause editing.
995 fn refresh_context_if_needed(
996 &mut self,
997 project: &Entity<Project>,
998 buffer: &Entity<language::Buffer>,
999 cursor_position: language::Anchor,
1000 cx: &mut Context<Self>,
1001 ) {
1002 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1003 return;
1004 }
1005
1006 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1007 return;
1008 };
1009
1010 let now = Instant::now();
1011 let was_idle = zeta_project
1012 .refresh_context_timestamp
1013 .map_or(true, |timestamp| {
1014 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1015 });
1016 zeta_project.refresh_context_timestamp = Some(now);
1017 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1018 let buffer = buffer.clone();
1019 let project = project.clone();
1020 async move |this, cx| {
1021 if was_idle {
1022 log::debug!("refetching edit prediction context after idle");
1023 } else {
1024 cx.background_executor()
1025 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1026 .await;
1027 log::debug!("refetching edit prediction context after pause");
1028 }
1029 this.update(cx, |this, cx| {
1030 this.refresh_context(project, buffer, cursor_position, cx);
1031 })
1032 .ok()
1033 }
1034 }));
1035 }
1036
1037 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1038 // and avoid spawning more than one concurrent task.
1039 fn refresh_context(
1040 &mut self,
1041 project: Entity<Project>,
1042 buffer: Entity<language::Buffer>,
1043 cursor_position: language::Anchor,
1044 cx: &mut Context<Self>,
1045 ) {
1046 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1047 return;
1048 };
1049
1050 zeta_project
1051 .refresh_context_task
1052 .get_or_insert(cx.spawn(async move |this, cx| {
1053 let related_excerpts = this
1054 .update(cx, |this, cx| {
1055 let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1056 return Task::ready(anyhow::Ok(HashMap::default()));
1057 };
1058
1059 let ContextMode::Llm(options) = &this.options().context else {
1060 return Task::ready(anyhow::Ok(HashMap::default()));
1061 };
1062
1063 find_related_excerpts(
1064 buffer.clone(),
1065 cursor_position,
1066 &project,
1067 zeta_project.events.iter(),
1068 options,
1069 cx,
1070 )
1071 })
1072 .ok()?
1073 .await
1074 .log_err()
1075 .unwrap_or_default();
1076 this.update(cx, |this, _cx| {
1077 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1078 return;
1079 };
1080 zeta_project.context = Some(related_excerpts);
1081 zeta_project.refresh_context_task.take();
1082 })
1083 .ok()
1084 }));
1085 }
1086
1087 fn gather_nearby_diagnostics(
1088 cursor_offset: usize,
1089 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1090 snapshot: &BufferSnapshot,
1091 max_diagnostics_bytes: usize,
1092 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1093 // TODO: Could make this more efficient
1094 let mut diagnostic_groups = Vec::new();
1095 for (language_server_id, diagnostics) in diagnostic_sets {
1096 let mut groups = Vec::new();
1097 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1098 diagnostic_groups.extend(
1099 groups
1100 .into_iter()
1101 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1102 );
1103 }
1104
1105 // sort by proximity to cursor
1106 diagnostic_groups.sort_by_key(|group| {
1107 let range = &group.entries[group.primary_ix].range;
1108 if range.start >= cursor_offset {
1109 range.start - cursor_offset
1110 } else if cursor_offset >= range.end {
1111 cursor_offset - range.end
1112 } else {
1113 (cursor_offset - range.start).min(range.end - cursor_offset)
1114 }
1115 });
1116
1117 let mut results = Vec::new();
1118 let mut diagnostic_groups_truncated = false;
1119 let mut diagnostics_byte_count = 0;
1120 for group in diagnostic_groups {
1121 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1122 diagnostics_byte_count += raw_value.get().len();
1123 if diagnostics_byte_count > max_diagnostics_bytes {
1124 diagnostic_groups_truncated = true;
1125 break;
1126 }
1127 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1128 }
1129
1130 (results, diagnostic_groups_truncated)
1131 }
1132
1133 // TODO: Dedupe with similar code in request_prediction?
1134 pub fn cloud_request_for_zeta_cli(
1135 &mut self,
1136 project: &Entity<Project>,
1137 buffer: &Entity<Buffer>,
1138 position: language::Anchor,
1139 cx: &mut Context<Self>,
1140 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1141 let project_state = self.projects.get(&project.entity_id());
1142
1143 let index_state = project_state.map(|state| {
1144 state
1145 .syntax_index
1146 .read_with(cx, |index, _cx| index.state().clone())
1147 });
1148 let options = self.options.clone();
1149 let snapshot = buffer.read(cx).snapshot();
1150 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1151 return Task::ready(Err(anyhow!("No file path for excerpt")));
1152 };
1153 let worktree_snapshots = project
1154 .read(cx)
1155 .worktrees(cx)
1156 .map(|worktree| worktree.read(cx).snapshot())
1157 .collect::<Vec<_>>();
1158
1159 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1160 let mut path = f.worktree.read(cx).absolutize(&f.path);
1161 if path.pop() { Some(path) } else { None }
1162 });
1163
1164 cx.background_spawn(async move {
1165 let index_state = if let Some(index_state) = index_state {
1166 Some(index_state.lock_owned().await)
1167 } else {
1168 None
1169 };
1170
1171 let cursor_point = position.to_point(&snapshot);
1172
1173 let debug_info = true;
1174 EditPredictionContext::gather_context(
1175 cursor_point,
1176 &snapshot,
1177 parent_abs_path.as_deref(),
1178 match &options.context {
1179 ContextMode::Llm(_) => {
1180 // TODO
1181 panic!("Llm mode not supported in zeta cli yet");
1182 }
1183 ContextMode::Syntax(edit_prediction_context_options) => {
1184 edit_prediction_context_options
1185 }
1186 },
1187 index_state.as_deref(),
1188 )
1189 .context("Failed to select excerpt")
1190 .map(|context| {
1191 make_syntax_context_cloud_request(
1192 excerpt_path.into(),
1193 context,
1194 // TODO pass everything
1195 Vec::new(),
1196 false,
1197 Vec::new(),
1198 false,
1199 None,
1200 debug_info,
1201 &worktree_snapshots,
1202 index_state.as_deref(),
1203 Some(options.max_prompt_bytes),
1204 options.prompt_format,
1205 )
1206 })
1207 })
1208 }
1209
1210 pub fn wait_for_initial_indexing(
1211 &mut self,
1212 project: &Entity<Project>,
1213 cx: &mut App,
1214 ) -> Task<Result<()>> {
1215 let zeta_project = self.get_or_init_zeta_project(project, cx);
1216 zeta_project
1217 .syntax_index
1218 .read(cx)
1219 .wait_for_initial_file_indexing(cx)
1220 }
1221}
1222
1223#[derive(Error, Debug)]
1224#[error(
1225 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1226)]
1227pub struct ZedUpdateRequiredError {
1228 minimum_version: SemanticVersion,
1229}
1230
1231fn make_syntax_context_cloud_request(
1232 excerpt_path: Arc<Path>,
1233 context: EditPredictionContext,
1234 events: Vec<predict_edits_v3::Event>,
1235 can_collect_data: bool,
1236 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1237 diagnostic_groups_truncated: bool,
1238 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1239 debug_info: bool,
1240 worktrees: &Vec<worktree::Snapshot>,
1241 index_state: Option<&SyntaxIndexState>,
1242 prompt_max_bytes: Option<usize>,
1243 prompt_format: PromptFormat,
1244) -> predict_edits_v3::PredictEditsRequest {
1245 let mut signatures = Vec::new();
1246 let mut declaration_to_signature_index = HashMap::default();
1247 let mut referenced_declarations = Vec::new();
1248
1249 for snippet in context.declarations {
1250 let project_entry_id = snippet.declaration.project_entry_id();
1251 let Some(path) = worktrees.iter().find_map(|worktree| {
1252 worktree.entry_for_id(project_entry_id).map(|entry| {
1253 let mut full_path = RelPathBuf::new();
1254 full_path.push(worktree.root_name());
1255 full_path.push(&entry.path);
1256 full_path
1257 })
1258 }) else {
1259 continue;
1260 };
1261
1262 let parent_index = index_state.and_then(|index_state| {
1263 snippet.declaration.parent().and_then(|parent| {
1264 add_signature(
1265 parent,
1266 &mut declaration_to_signature_index,
1267 &mut signatures,
1268 index_state,
1269 )
1270 })
1271 });
1272
1273 let (text, text_is_truncated) = snippet.declaration.item_text();
1274 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1275 path: path.as_std_path().into(),
1276 text: text.into(),
1277 range: snippet.declaration.item_line_range(),
1278 text_is_truncated,
1279 signature_range: snippet.declaration.signature_range_in_item_text(),
1280 parent_index,
1281 signature_score: snippet.score(DeclarationStyle::Signature),
1282 declaration_score: snippet.score(DeclarationStyle::Declaration),
1283 score_components: snippet.components,
1284 });
1285 }
1286
1287 let excerpt_parent = index_state.and_then(|index_state| {
1288 context
1289 .excerpt
1290 .parent_declarations
1291 .last()
1292 .and_then(|(parent, _)| {
1293 add_signature(
1294 *parent,
1295 &mut declaration_to_signature_index,
1296 &mut signatures,
1297 index_state,
1298 )
1299 })
1300 });
1301
1302 predict_edits_v3::PredictEditsRequest {
1303 excerpt_path,
1304 excerpt: context.excerpt_text.body,
1305 excerpt_line_range: context.excerpt.line_range,
1306 excerpt_range: context.excerpt.range,
1307 cursor_point: predict_edits_v3::Point {
1308 line: predict_edits_v3::Line(context.cursor_point.row),
1309 column: context.cursor_point.column,
1310 },
1311 referenced_declarations,
1312 included_files: vec![],
1313 signatures,
1314 excerpt_parent,
1315 events,
1316 can_collect_data,
1317 diagnostic_groups,
1318 diagnostic_groups_truncated,
1319 git_info,
1320 debug_info,
1321 prompt_max_bytes,
1322 prompt_format,
1323 }
1324}
1325
1326fn add_signature(
1327 declaration_id: DeclarationId,
1328 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1329 signatures: &mut Vec<Signature>,
1330 index: &SyntaxIndexState,
1331) -> Option<usize> {
1332 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1333 return Some(*signature_index);
1334 }
1335 let Some(parent_declaration) = index.declaration(declaration_id) else {
1336 log::error!("bug: missing parent declaration");
1337 return None;
1338 };
1339 let parent_index = parent_declaration.parent().and_then(|parent| {
1340 add_signature(parent, declaration_to_signature_index, signatures, index)
1341 });
1342 let (text, text_is_truncated) = parent_declaration.signature_text();
1343 let signature_index = signatures.len();
1344 signatures.push(Signature {
1345 text: text.into(),
1346 text_is_truncated,
1347 parent_index,
1348 range: parent_declaration.signature_line_range(),
1349 });
1350 declaration_to_signature_index.insert(declaration_id, signature_index);
1351 Some(signature_index)
1352}
1353
#[cfg(test)]
mod tests {
    // Tests for Zeta's prediction requests. `init_test` installs a fake HTTP client that
    // forwards every `/predict_edits/v3` request to the test through a channel, together with
    // a oneshot sender the test uses to supply the server's response.
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Checks what `current_prediction_for_buffer` reports after `refresh_prediction`:
    // - when the predicted edit targets the buffer itself, the prediction is `Local`;
    // - when it targets another file, the prediction is a `Jump` to that file;
    // - once that other file is opened, the prediction is `Local` for its buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Intercept the outgoing request and answer it with a whole-file replacement.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // `2.txt` has the same number of lines as `1.txt`, so `snapshot1` also bounds the
        // whole-file range used for this edit.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Opening the target file turns the jump into a local prediction for that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Checks the request's excerpt path and cursor point, and that a whole-file replacement in
    // the response comes back as a single edit inserting " are you?" at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Checks that an edit made to a buffer after `register_buffer` is included in the next
    // request as a `BufferChange` event carrying a unified diff of the change.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Checks that diagnostics pushed into the LSP store for the file are attached to the
    // request as serialized diagnostic groups.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // Spans line 1 ("Bye") from column 1; the end column 5 is past the end of that line.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // Only the outgoing request is inspected; no response is sent.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // Expected offsets: line 1 starts at byte 7, so column 1 is offset 8. The end is
        // offset 10, the end of "Bye", because the out-of-range end column is clamped.
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Sets up test globals (settings, language, project settings, language models) and a
    // `Zeta` instance whose client uses a fake HTTP client:
    // - `/client/llm_tokens` returns a fixed token;
    // - `/predict_edits/v3` deserializes the request body, sends it on the returned receiver
    //   together with a oneshot sender, and replies with whatever response the test sends;
    // - any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for its response.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}