1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::fmt::Write;
32use std::ops::Range;
33use std::path::Path;
34use std::str::FromStr as _;
35use std::sync::Arc;
36use std::time::{Duration, Instant};
37use thiserror::Error;
38use util::ResultExt as _;
39use util::rel_path::RelPathBuf;
40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
41
42pub mod merge_excerpts;
43mod prediction;
44mod provider;
45pub mod related_excerpts;
46
47use crate::merge_excerpts::merge_excerpts;
48use crate::prediction::EditPrediction;
49use crate::related_excerpts::find_related_excerpts;
50pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
51pub use provider::ZetaEditPredictionProvider;
52
53/// Maximum number of events to track.
54const MAX_EVENT_COUNT: usize = 16;
55
56pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
57 max_bytes: 512,
58 min_bytes: 128,
59 target_before_cursor_over_total_bytes: 0.5,
60};
61
62pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
63
64pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
65 excerpt: DEFAULT_EXCERPT_OPTIONS,
66};
67
68pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
69 EditPredictionContextOptions {
70 use_imports: true,
71 max_retrieved_declarations: 0,
72 excerpt: DEFAULT_EXCERPT_OPTIONS,
73 score: EditPredictionScoreOptions {
74 omit_excerpt_overlaps: true,
75 },
76 };
77
78pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
79 context: DEFAULT_CONTEXT_OPTIONS,
80 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
81 max_diagnostic_bytes: 2048,
82 prompt_format: PromptFormat::DEFAULT,
83 file_indexing_parallelism: 1,
84 buffer_change_grouping_interval: Duration::from_secs(1),
85};
86
/// Feature flag (named `"zeta2"`) gating the zeta2 edit-prediction model.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// Returning `false` opts this flag out of staff-wide enablement
    /// (presumably it must then be granted explicitly — confirm in `feature_flags`).
    fn enabled_for_staff() -> bool {
        false
    }
}
96
/// App-global handle to the shared [`Zeta`] entity (installed by `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

/// Edit-prediction engine: per-project edit history and context, plus the
/// client and token used to talk to the prediction service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive as long as `Zeta` lives.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is below the minimum.
    update_required: bool,
    /// When installed via `debug_info`, debug events are sent on this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
112
/// Tunable options controlling how Zeta gathers context and builds requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How related context is gathered (LLM-driven search or syntax index).
    pub context: ContextMode,
    /// Prompt size limit, forwarded in requests as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for diagnostics gathered near the cursor.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
    /// Consecutive changes to the same buffer arriving within this interval are
    /// merged into one event; `Duration::ZERO` disables merging.
    pub buffer_change_grouping_interval: Duration,
}
122
123#[derive(Debug, Clone, PartialEq)]
124pub enum ContextMode {
125 Llm(LlmContextOptions),
126 Syntax(EditPredictionContextOptions),
127}
128
129impl ContextMode {
130 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
131 match self {
132 ContextMode::Llm(options) => &options.excerpt,
133 ContextMode::Syntax(options) => &options.excerpt,
134 }
135 }
136}
137
/// Debug events delivered on the channel returned by `Zeta::debug_info`.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    /// Context retrieval for a project began (presumably sent by `related_excerpts`).
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    /// Search queries were generated for context retrieval.
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
    /// Context retrieval finished and the project's context was replaced.
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// A prediction request was built (and sent unless skipped).
    EditPredicted(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Payload for context-retrieval progress events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredicted`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    /// The request as built.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics and context for the request.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Prompt built locally from `request`, or the build error as a string.
    pub local_prompt: Result<String, String>,
    /// Receives the server response, or an error string (including a skipped request).
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
179
/// Per-project prediction state.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent edit events, oldest first, holding at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers tracked for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts from the last context refresh; `None` until one completes.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context refresh; cleared when it finishes.
    refresh_context_task: Option<Task<Option<()>>>,
    /// Pending debounced context refresh; replaced on every new trigger.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When a context refresh was last triggered; used to detect idle periods.
    refresh_context_timestamp: Option<Instant>,
}
190
191#[derive(Debug, Clone)]
192struct CurrentEditPrediction {
193 pub requested_by_buffer_id: EntityId,
194 pub prediction: EditPrediction,
195}
196
197impl CurrentEditPrediction {
198 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
199 let Some(new_edits) = self
200 .prediction
201 .interpolate(&self.prediction.buffer.read(cx))
202 else {
203 return false;
204 };
205
206 if self.prediction.buffer != old_prediction.prediction.buffer {
207 return true;
208 }
209
210 let Some(old_edits) = old_prediction
211 .prediction
212 .interpolate(&old_prediction.prediction.buffer.read(cx))
213 else {
214 return true;
215 };
216
217 // This reduces the occurrence of UI thrash from replacing edits
218 //
219 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
220 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
221 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
222 && old_edits.len() == 1
223 && new_edits.len() == 1
224 {
225 let (old_range, old_text) = &old_edits[0];
226 let (new_range, new_text) = &new_edits[0];
227 new_range == old_range && new_text.starts_with(old_text)
228 } else {
229 true
230 }
231 }
232}
233
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits another buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}

/// A buffer tracked for edit events.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; new changes are diffed against it.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer for the buffer.
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes may
    /// be merged into one event; `timestamp` is the time of the latest merged change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
254
255impl Event {
256 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
257 match self {
258 Event::BufferChange {
259 old_snapshot,
260 new_snapshot,
261 ..
262 } => {
263 let path = new_snapshot.file().map(|f| f.full_path(cx));
264
265 let old_path = old_snapshot.file().and_then(|f| {
266 let old_path = f.full_path(cx);
267 if Some(&old_path) != path.as_ref() {
268 Some(old_path)
269 } else {
270 None
271 }
272 });
273
274 // TODO [zeta2] move to bg?
275 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
276
277 if path == old_path && diff.is_empty() {
278 None
279 } else {
280 Some(predict_edits_v3::Event::BufferChange {
281 old_path,
282 path,
283 diff,
284 //todo: Actually detect if this edit was predicted or not
285 predicted: false,
286 })
287 }
288 }
289 }
290 }
291}
292
293impl Zeta {
294 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
295 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
296 }
297
298 pub fn global(
299 client: &Arc<Client>,
300 user_store: &Entity<UserStore>,
301 cx: &mut App,
302 ) -> Entity<Self> {
303 cx.try_global::<ZetaGlobal>()
304 .map(|global| global.0.clone())
305 .unwrap_or_else(|| {
306 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
307 cx.set_global(ZetaGlobal(zeta.clone()));
308 zeta
309 })
310 }
311
312 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
313 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
314
315 Self {
316 projects: HashMap::default(),
317 client,
318 user_store,
319 options: DEFAULT_OPTIONS,
320 llm_token: LlmApiToken::default(),
321 _llm_token_subscription: cx.subscribe(
322 &refresh_llm_token_listener,
323 |this, _listener, _event, cx| {
324 let client = this.client.clone();
325 let llm_token = this.llm_token.clone();
326 cx.spawn(async move |_this, _cx| {
327 llm_token.refresh(&client).await?;
328 anyhow::Ok(())
329 })
330 .detach_and_log_err(cx);
331 },
332 ),
333 update_required: false,
334 debug_tx: None,
335 }
336 }
337
338 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
339 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
340 self.debug_tx = Some(debug_watch_tx);
341 debug_watch_rx
342 }
343
344 pub fn options(&self) -> &ZetaOptions {
345 &self.options
346 }
347
348 pub fn set_options(&mut self, options: ZetaOptions) {
349 self.options = options;
350 }
351
352 pub fn clear_history(&mut self) {
353 for zeta_project in self.projects.values_mut() {
354 zeta_project.events.clear();
355 }
356 }
357
358 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
359 self.projects
360 .get(&project.entity_id())
361 .map(|project| project.events.iter())
362 .into_iter()
363 .flatten()
364 }
365
366 pub fn context_for_project(
367 &self,
368 project: &Entity<Project>,
369 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
370 self.projects
371 .get(&project.entity_id())
372 .and_then(|project| {
373 Some(
374 project
375 .context
376 .as_ref()?
377 .iter()
378 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
379 )
380 })
381 .into_iter()
382 .flatten()
383 }
384
385 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
386 self.user_store.read(cx).edit_prediction_usage()
387 }
388
389 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
390 self.get_or_init_zeta_project(project, cx);
391 }
392
393 pub fn register_buffer(
394 &mut self,
395 buffer: &Entity<Buffer>,
396 project: &Entity<Project>,
397 cx: &mut Context<Self>,
398 ) {
399 let zeta_project = self.get_or_init_zeta_project(project, cx);
400 Self::register_buffer_impl(zeta_project, buffer, project, cx);
401 }
402
403 fn get_or_init_zeta_project(
404 &mut self,
405 project: &Entity<Project>,
406 cx: &mut App,
407 ) -> &mut ZetaProject {
408 self.projects
409 .entry(project.entity_id())
410 .or_insert_with(|| ZetaProject {
411 syntax_index: cx.new(|cx| {
412 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
413 }),
414 events: VecDeque::new(),
415 registered_buffers: HashMap::default(),
416 current_prediction: None,
417 context: None,
418 refresh_context_task: None,
419 refresh_context_debounce_task: None,
420 refresh_context_timestamp: None,
421 })
422 }
423
    /// Returns the registration for `buffer` in `zeta_project`, creating it if needed.
    ///
    /// A new registration stores the buffer's current snapshot and two subscriptions:
    /// one that calls `report_changes_for_buffer` on every `BufferEvent::Edited`
    /// (as long as the project can still be upgraded), and one that removes the
    /// registration when the buffer is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Held weakly; edits are ignored once the project is gone.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration once the buffer entity is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
461
462 fn report_changes_for_buffer(
463 &mut self,
464 buffer: &Entity<Buffer>,
465 project: &Entity<Project>,
466 cx: &mut Context<Self>,
467 ) -> BufferSnapshot {
468 let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
469 let zeta_project = self.get_or_init_zeta_project(project, cx);
470 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
471
472 let new_snapshot = buffer.read(cx).snapshot();
473 if new_snapshot.version != registered_buffer.snapshot.version {
474 let old_snapshot =
475 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
476 Self::push_event(
477 zeta_project,
478 buffer_change_grouping_interval,
479 Event::BufferChange {
480 old_snapshot,
481 new_snapshot: new_snapshot.clone(),
482 timestamp: Instant::now(),
483 },
484 );
485 }
486
487 new_snapshot
488 }
489
    /// Appends `event` to the project's edit history.
    ///
    /// If `buffer_change_grouping_interval` is non-zero and `event` directly continues
    /// the most recent event (same buffer, its old snapshot has the version of the last
    /// event's new snapshot, and it arrived within the interval), the two are merged by
    /// advancing the last event's new snapshot and timestamp. Otherwise the event is
    /// appended, first dropping the oldest half of the history if it is full.
    fn push_event(
        zeta_project: &mut ZetaProject,
        buffer_change_grouping_interval: Duration,
        event: Event,
    ) {
        let events = &mut zeta_project.events;

        if buffer_change_grouping_interval > Duration::ZERO
            && let Some(Event::BufferChange {
                new_snapshot: last_new_snapshot,
                timestamp: last_timestamp,
                ..
            }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                // The merged event keeps its original old snapshot.
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
528
529 fn current_prediction_for_buffer(
530 &self,
531 buffer: &Entity<Buffer>,
532 project: &Entity<Project>,
533 cx: &App,
534 ) -> Option<BufferEditPrediction<'_>> {
535 let project_state = self.projects.get(&project.entity_id())?;
536
537 let CurrentEditPrediction {
538 requested_by_buffer_id,
539 prediction,
540 } = project_state.current_prediction.as_ref()?;
541
542 if prediction.targets_buffer(buffer.read(cx), cx) {
543 Some(BufferEditPrediction::Local { prediction })
544 } else if *requested_by_buffer_id == buffer.entity_id() {
545 Some(BufferEditPrediction::Jump { prediction })
546 } else {
547 None
548 }
549 }
550
    /// Takes the project's current prediction and reports its acceptance to the server
    /// at `/predict_edits/accept` (overridable via `ZED_ACCEPT_PREDICTION_URL`).
    ///
    /// The report runs in a detached task and any failure is only logged.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Cleared locally before the server call, so it is gone even if reporting fails.
        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Records usage / update-required state from the response.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
592
593 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
594 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
595 project_state.current_prediction.take();
596 };
597 }
598
599 pub fn refresh_prediction(
600 &mut self,
601 project: &Entity<Project>,
602 buffer: &Entity<Buffer>,
603 position: language::Anchor,
604 cx: &mut Context<Self>,
605 ) -> Task<Result<()>> {
606 let request_task = self.request_prediction(project, buffer, position, cx);
607 let buffer = buffer.clone();
608 let project = project.clone();
609
610 cx.spawn(async move |this, cx| {
611 if let Some(prediction) = request_task.await? {
612 this.update(cx, |this, cx| {
613 let project_state = this
614 .projects
615 .get_mut(&project.entity_id())
616 .context("Project not found")?;
617
618 let new_prediction = CurrentEditPrediction {
619 requested_by_buffer_id: buffer.entity_id(),
620 prediction: prediction,
621 };
622
623 if project_state
624 .current_prediction
625 .as_ref()
626 .is_none_or(|old_prediction| {
627 new_prediction.should_replace_prediction(&old_prediction, cx)
628 })
629 {
630 project_state.current_prediction = Some(new_prediction);
631 }
632 anyhow::Ok(())
633 })??;
634 }
635 Ok(())
636 })
637 }
638
    /// Builds a prediction request for `position` in `buffer` and sends it to the
    /// prediction service.
    ///
    /// The request is built on a background task. That task gathers diagnostics near
    /// the cursor (within `options.max_diagnostic_bytes`) and, depending on
    /// `options.context`, does one of the following:
    /// - `ContextMode::Llm`: selects an excerpt around the cursor and merges it into
    ///   the related excerpts from the last context refresh (`included_files`).
    /// - `ContextMode::Syntax`: gathers context through the syntax index and builds
    ///   the request with `make_syntax_context_cloud_request`.
    ///
    /// Resolves to `Ok(None)` when no excerpt/context could be selected at the cursor,
    /// or when the response yields no prediction.
    ///
    /// # Errors
    /// Returns an already-failed task if the buffer has no file. Otherwise the task fails
    /// if sending the request fails or, in debug builds only, when
    /// `ZED_ZETA2_SKIP_REQUEST` is set.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Shared handle to the syntax index state (absent for unregistered projects);
        // it is locked later, on the background task.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Edit history in wire format; events with nothing to report are skipped.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        // Absolute path of the buffer's parent directory (used by syntax-based context).
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Related excerpts from the last context refresh, each paired with its buffer's
        // current snapshot and full path. Buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer, ranges)| {
                let buffer = buffer.read(cx);
                Some((
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = match options.context {
                    ContextMode::Llm(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
                            ..snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the included files. If the cursor's
                        // buffer already has related excerpts, insert the range so the list
                        // stays ordered by start (ascending) then end (descending), and move
                        // that buffer to the end of the list; otherwise append it last.
                        if let Some(buffer_ix) = included_files
                            .iter()
                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
                        {
                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to row ranges and merge overlapping excerpts.
                        let included_files = included_files
                            .into_iter()
                            .map(|(buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path,
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        // The standalone excerpt fields are left empty in this mode; the
                        // cursor excerpt travels in `included_files` instead.
                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                // Time spent gathering diagnostics and context.
                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // When debugging, publish the request and a locally built prompt, and keep
                // a sender so the response (or error) can be forwarded afterwards.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    let local_prompt = build_prompt(&request)
                        .map(|(prompt, _)| prompt)
                        .map_err(|err| err.to_string());

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
                            request: request.clone(),
                            retrieval_time,
                            buffer: buffer.downgrade(),
                            local_prompt,
                            position,
                            response_rx,
                        }))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send(Err("Request skipped".to_string()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let response =
                    Self::send_prediction_request(client, llm_token, app_version, request).await;

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                        )
                        .ok();
                }

                response.map(|(res, usage)| (Some(res), usage))
            }
        });

        let buffer = buffer.clone();

        cx.spawn({
            let project = project.clone();
            async move |this, cx| {
                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
            }
        })
    }
903
904 async fn send_prediction_request(
905 client: Arc<Client>,
906 llm_token: LlmApiToken,
907 app_version: SemanticVersion,
908 request: predict_edits_v3::PredictEditsRequest,
909 ) -> Result<(
910 predict_edits_v3::PredictEditsResponse,
911 Option<EditPredictionUsage>,
912 )> {
913 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
914 http_client::Url::parse(&predict_edits_url)?
915 } else {
916 client
917 .http_client()
918 .build_zed_llm_url("/predict_edits/v3", &[])?
919 };
920
921 Self::send_api_request(
922 |builder| {
923 let req = builder
924 .uri(url.as_ref())
925 .body(serde_json::to_string(&request)?.into());
926 Ok(req?)
927 },
928 client,
929 llm_token,
930 app_version,
931 )
932 .await
933 }
934
935 fn handle_api_response<T>(
936 this: &WeakEntity<Self>,
937 response: Result<(T, Option<EditPredictionUsage>)>,
938 cx: &mut gpui::AsyncApp,
939 ) -> Result<T> {
940 match response {
941 Ok((data, usage)) => {
942 if let Some(usage) = usage {
943 this.update(cx, |this, cx| {
944 this.user_store.update(cx, |user_store, cx| {
945 user_store.update_edit_prediction_usage(usage, cx);
946 });
947 })
948 .ok();
949 }
950 Ok(data)
951 }
952 Err(err) => {
953 if err.is::<ZedUpdateRequiredError>() {
954 cx.update(|cx| {
955 this.update(cx, |this, _cx| {
956 this.update_required = true;
957 })
958 .ok();
959
960 let error_message: SharedString = err.to_string().into();
961 show_app_notification(
962 NotificationId::unique::<ZedUpdateRequiredError>(),
963 cx,
964 move |cx| {
965 cx.new(|cx| {
966 ErrorMessagePrompt::new(error_message.clone(), cx)
967 .with_link_button("Update Zed", "https://zed.dev/releases")
968 })
969 },
970 );
971 })
972 .ok();
973 }
974 Err(err)
975 }
976 }
977 }
978
    /// Sends an authenticated POST request built by `build` and decodes a JSON `Res` body.
    ///
    /// `build` receives a builder with `Content-Type`, bearer `Authorization`, and Zed
    /// version headers already set, and must add the URI and body. It may be called a
    /// second time when the request is retried after a token refresh.
    ///
    /// Returns the decoded body and any usage parsed from the response headers.
    ///
    /// # Errors
    /// - [`ZedUpdateRequiredError`] if the response names a minimum version newer than
    ///   `app_version`;
    /// - an error with the status and body for non-success responses;
    /// - token, transport, or JSON decoding errors.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Checked before the status, so it applies to failed responses too.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // The server flagged the token as expired: refresh it and retry once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1042
1043 pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1044 pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1045
1046 // Refresh the related excerpts when the user just beguns editing after
1047 // an idle period, and after they pause editing.
1048 fn refresh_context_if_needed(
1049 &mut self,
1050 project: &Entity<Project>,
1051 buffer: &Entity<language::Buffer>,
1052 cursor_position: language::Anchor,
1053 cx: &mut Context<Self>,
1054 ) {
1055 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1056 return;
1057 }
1058
1059 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1060 return;
1061 };
1062
1063 let now = Instant::now();
1064 let was_idle = zeta_project
1065 .refresh_context_timestamp
1066 .map_or(true, |timestamp| {
1067 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1068 });
1069 zeta_project.refresh_context_timestamp = Some(now);
1070 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1071 let buffer = buffer.clone();
1072 let project = project.clone();
1073 async move |this, cx| {
1074 if was_idle {
1075 log::debug!("refetching edit prediction context after idle");
1076 } else {
1077 cx.background_executor()
1078 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1079 .await;
1080 log::debug!("refetching edit prediction context after pause");
1081 }
1082 this.update(cx, |this, cx| {
1083 this.refresh_context(project, buffer, cursor_position, cx);
1084 })
1085 .ok()
1086 }
1087 }));
1088 }
1089
1090 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1091 // and avoid spawning more than one concurrent task.
1092 fn refresh_context(
1093 &mut self,
1094 project: Entity<Project>,
1095 buffer: Entity<language::Buffer>,
1096 cursor_position: language::Anchor,
1097 cx: &mut Context<Self>,
1098 ) {
1099 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1100 return;
1101 };
1102
1103 let debug_tx = self.debug_tx.clone();
1104
1105 zeta_project
1106 .refresh_context_task
1107 .get_or_insert(cx.spawn(async move |this, cx| {
1108 let related_excerpts = this
1109 .update(cx, |this, cx| {
1110 let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1111 return Task::ready(anyhow::Ok(HashMap::default()));
1112 };
1113
1114 let ContextMode::Llm(options) = &this.options().context else {
1115 return Task::ready(anyhow::Ok(HashMap::default()));
1116 };
1117
1118 let mut edit_history_unified_diff = String::new();
1119
1120 for event in zeta_project.events.iter() {
1121 if let Some(event) = event.to_request_event(cx) {
1122 writeln!(&mut edit_history_unified_diff, "{event}").ok();
1123 }
1124 }
1125
1126 find_related_excerpts(
1127 buffer.clone(),
1128 cursor_position,
1129 &project,
1130 edit_history_unified_diff,
1131 options,
1132 debug_tx,
1133 cx,
1134 )
1135 })
1136 .ok()?
1137 .await
1138 .log_err()
1139 .unwrap_or_default();
1140 this.update(cx, |this, _cx| {
1141 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1142 return;
1143 };
1144 zeta_project.context = Some(related_excerpts);
1145 zeta_project.refresh_context_task.take();
1146 if let Some(debug_tx) = &this.debug_tx {
1147 debug_tx
1148 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1149 ZetaContextRetrievalDebugInfo {
1150 project,
1151 timestamp: Instant::now(),
1152 },
1153 ))
1154 .ok();
1155 }
1156 })
1157 .ok()
1158 }));
1159 }
1160
1161 fn gather_nearby_diagnostics(
1162 cursor_offset: usize,
1163 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1164 snapshot: &BufferSnapshot,
1165 max_diagnostics_bytes: usize,
1166 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1167 // TODO: Could make this more efficient
1168 let mut diagnostic_groups = Vec::new();
1169 for (language_server_id, diagnostics) in diagnostic_sets {
1170 let mut groups = Vec::new();
1171 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1172 diagnostic_groups.extend(
1173 groups
1174 .into_iter()
1175 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1176 );
1177 }
1178
1179 // sort by proximity to cursor
1180 diagnostic_groups.sort_by_key(|group| {
1181 let range = &group.entries[group.primary_ix].range;
1182 if range.start >= cursor_offset {
1183 range.start - cursor_offset
1184 } else if cursor_offset >= range.end {
1185 cursor_offset - range.end
1186 } else {
1187 (cursor_offset - range.start).min(range.end - cursor_offset)
1188 }
1189 });
1190
1191 let mut results = Vec::new();
1192 let mut diagnostic_groups_truncated = false;
1193 let mut diagnostics_byte_count = 0;
1194 for group in diagnostic_groups {
1195 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1196 diagnostics_byte_count += raw_value.get().len();
1197 if diagnostics_byte_count > max_diagnostics_bytes {
1198 diagnostic_groups_truncated = true;
1199 break;
1200 }
1201 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1202 }
1203
1204 (results, diagnostic_groups_truncated)
1205 }
1206
1207 // TODO: Dedupe with similar code in request_prediction?
1208 pub fn cloud_request_for_zeta_cli(
1209 &mut self,
1210 project: &Entity<Project>,
1211 buffer: &Entity<Buffer>,
1212 position: language::Anchor,
1213 cx: &mut Context<Self>,
1214 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1215 let project_state = self.projects.get(&project.entity_id());
1216
1217 let index_state = project_state.map(|state| {
1218 state
1219 .syntax_index
1220 .read_with(cx, |index, _cx| index.state().clone())
1221 });
1222 let options = self.options.clone();
1223 let snapshot = buffer.read(cx).snapshot();
1224 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1225 return Task::ready(Err(anyhow!("No file path for excerpt")));
1226 };
1227 let worktree_snapshots = project
1228 .read(cx)
1229 .worktrees(cx)
1230 .map(|worktree| worktree.read(cx).snapshot())
1231 .collect::<Vec<_>>();
1232
1233 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1234 let mut path = f.worktree.read(cx).absolutize(&f.path);
1235 if path.pop() { Some(path) } else { None }
1236 });
1237
1238 cx.background_spawn(async move {
1239 let index_state = if let Some(index_state) = index_state {
1240 Some(index_state.lock_owned().await)
1241 } else {
1242 None
1243 };
1244
1245 let cursor_point = position.to_point(&snapshot);
1246
1247 let debug_info = true;
1248 EditPredictionContext::gather_context(
1249 cursor_point,
1250 &snapshot,
1251 parent_abs_path.as_deref(),
1252 match &options.context {
1253 ContextMode::Llm(_) => {
1254 // TODO
1255 panic!("Llm mode not supported in zeta cli yet");
1256 }
1257 ContextMode::Syntax(edit_prediction_context_options) => {
1258 edit_prediction_context_options
1259 }
1260 },
1261 index_state.as_deref(),
1262 )
1263 .context("Failed to select excerpt")
1264 .map(|context| {
1265 make_syntax_context_cloud_request(
1266 excerpt_path.into(),
1267 context,
1268 // TODO pass everything
1269 Vec::new(),
1270 false,
1271 Vec::new(),
1272 false,
1273 None,
1274 debug_info,
1275 &worktree_snapshots,
1276 index_state.as_deref(),
1277 Some(options.max_prompt_bytes),
1278 options.prompt_format,
1279 )
1280 })
1281 })
1282 }
1283
1284 pub fn wait_for_initial_indexing(
1285 &mut self,
1286 project: &Entity<Project>,
1287 cx: &mut App,
1288 ) -> Task<Result<()>> {
1289 let zeta_project = self.get_or_init_zeta_project(project, cx);
1290 zeta_project
1291 .syntax_index
1292 .read(cx)
1293 .wait_for_initial_file_indexing(cx)
1294 }
1295}
1296
1297#[derive(Error, Debug)]
1298#[error(
1299 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1300)]
1301pub struct ZedUpdateRequiredError {
1302 minimum_version: SemanticVersion,
1303}
1304
1305fn make_syntax_context_cloud_request(
1306 excerpt_path: Arc<Path>,
1307 context: EditPredictionContext,
1308 events: Vec<predict_edits_v3::Event>,
1309 can_collect_data: bool,
1310 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1311 diagnostic_groups_truncated: bool,
1312 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1313 debug_info: bool,
1314 worktrees: &Vec<worktree::Snapshot>,
1315 index_state: Option<&SyntaxIndexState>,
1316 prompt_max_bytes: Option<usize>,
1317 prompt_format: PromptFormat,
1318) -> predict_edits_v3::PredictEditsRequest {
1319 let mut signatures = Vec::new();
1320 let mut declaration_to_signature_index = HashMap::default();
1321 let mut referenced_declarations = Vec::new();
1322
1323 for snippet in context.declarations {
1324 let project_entry_id = snippet.declaration.project_entry_id();
1325 let Some(path) = worktrees.iter().find_map(|worktree| {
1326 worktree.entry_for_id(project_entry_id).map(|entry| {
1327 let mut full_path = RelPathBuf::new();
1328 full_path.push(worktree.root_name());
1329 full_path.push(&entry.path);
1330 full_path
1331 })
1332 }) else {
1333 continue;
1334 };
1335
1336 let parent_index = index_state.and_then(|index_state| {
1337 snippet.declaration.parent().and_then(|parent| {
1338 add_signature(
1339 parent,
1340 &mut declaration_to_signature_index,
1341 &mut signatures,
1342 index_state,
1343 )
1344 })
1345 });
1346
1347 let (text, text_is_truncated) = snippet.declaration.item_text();
1348 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1349 path: path.as_std_path().into(),
1350 text: text.into(),
1351 range: snippet.declaration.item_line_range(),
1352 text_is_truncated,
1353 signature_range: snippet.declaration.signature_range_in_item_text(),
1354 parent_index,
1355 signature_score: snippet.score(DeclarationStyle::Signature),
1356 declaration_score: snippet.score(DeclarationStyle::Declaration),
1357 score_components: snippet.components,
1358 });
1359 }
1360
1361 let excerpt_parent = index_state.and_then(|index_state| {
1362 context
1363 .excerpt
1364 .parent_declarations
1365 .last()
1366 .and_then(|(parent, _)| {
1367 add_signature(
1368 *parent,
1369 &mut declaration_to_signature_index,
1370 &mut signatures,
1371 index_state,
1372 )
1373 })
1374 });
1375
1376 predict_edits_v3::PredictEditsRequest {
1377 excerpt_path,
1378 excerpt: context.excerpt_text.body,
1379 excerpt_line_range: context.excerpt.line_range,
1380 excerpt_range: context.excerpt.range,
1381 cursor_point: predict_edits_v3::Point {
1382 line: predict_edits_v3::Line(context.cursor_point.row),
1383 column: context.cursor_point.column,
1384 },
1385 referenced_declarations,
1386 included_files: vec![],
1387 signatures,
1388 excerpt_parent,
1389 events,
1390 can_collect_data,
1391 diagnostic_groups,
1392 diagnostic_groups_truncated,
1393 git_info,
1394 debug_info,
1395 prompt_max_bytes,
1396 prompt_format,
1397 }
1398}
1399
1400fn add_signature(
1401 declaration_id: DeclarationId,
1402 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1403 signatures: &mut Vec<Signature>,
1404 index: &SyntaxIndexState,
1405) -> Option<usize> {
1406 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1407 return Some(*signature_index);
1408 }
1409 let Some(parent_declaration) = index.declaration(declaration_id) else {
1410 log::error!("bug: missing parent declaration");
1411 return None;
1412 };
1413 let parent_index = parent_declaration.parent().and_then(|parent| {
1414 add_signature(parent, declaration_to_signature_index, signatures, index)
1415 });
1416 let (text, text_is_truncated) = parent_declaration.signature_text();
1417 let signature_index = signatures.len();
1418 signatures.push(Signature {
1419 text: text.into(),
1420 text_is_truncated,
1421 parent_index,
1422 range: parent_declaration.signature_line_range(),
1423 });
1424 declaration_to_signature_index.insert(declaration_id, signature_index);
1425 Some(signature_index)
1426}
1427
// Tests for Zeta's prediction requests and prediction state. They run against a fake
// filesystem and a fake HTTP client that hands each prediction request to the test
// (see `init_test`).
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Checks how `current_prediction_for_buffer` classifies predictions:
    // - a response editing the cursor's own file yields `Local` for that buffer;
    // - a response editing another file yields `Jump` to that file;
    // - once the other file is opened, its buffer sees `Local`.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // The fake server forwards the request here; reply with a whole-file rewrite.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1's perspective, an edit to 2.txt is a jump to that file.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // The same prediction is local from the target buffer's perspective.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Verifies the request's excerpt path and cursor point. Also checks that a whole-file
    // replacement response is reduced to a single insertion at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // Only the inserted text after "How" on line 1 should remain as an edit.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // An edit made after `register_buffer` is sent in the next request as a `BufferChange`
    // event carrying a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line; this fills it with "How".
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Diagnostics pushed through the LSP store are serialized into the request's
    // `diagnostic_groups`.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // This LSP range runs past the end of line 1 ("Bye"). It resolves to byte offsets
        // 8..10, as asserted below.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        // Only the request is inspected. The responder is held (not dropped) but never used.
        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Sets up test settings, a fake HTTP client and the global `Zeta`. The fake client:
    // - answers `/client/llm_tokens` with a test token;
    // - forwards each deserialized `/predict_edits/v3` request, together with a oneshot
    //   responder, on the returned receiver, then serializes whatever response the test sends
    //   back;
    // - panics on any other path.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for its reply.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}