1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
7 ZED_VERSION_HEADER_NAME,
8};
9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
10use collections::HashMap;
11use edit_prediction_context::{
12 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
13 EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
14 SyntaxIndex, SyntaxIndexState,
15};
16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
17use futures::AsyncReadExt as _;
18use futures::channel::{mpsc, oneshot};
19use gpui::http_client::{AsyncBody, Method};
20use gpui::{
21 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
22 http_client, prelude::*,
23};
24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
25use language::{BufferSnapshot, OffsetRangeExt};
26use language_model::{LlmApiToken, RefreshLlmTokenListener};
27use project::Project;
28use release_channel::AppVersion;
29use serde::de::DeserializeOwned;
30use std::collections::{VecDeque, hash_map};
31use std::ops::Range;
32use std::path::Path;
33use std::str::FromStr as _;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use thiserror::Error;
37use util::ResultExt as _;
38use util::rel_path::RelPathBuf;
39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
40
41mod merge_excerpts;
42mod prediction;
43mod provider;
44mod related_excerpts;
45
46use crate::merge_excerpts::merge_excerpts;
47use crate::prediction::EditPrediction;
48use crate::related_excerpts::find_related_excerpts;
49pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
50pub use provider::ZetaEditPredictionProvider;
51
/// Consecutive edits to the same buffer that arrive within this window of the
/// previously recorded edit are folded into a single `Event::BufferChange`
/// instead of being recorded as separate history entries.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of events kept in a project's edit history.
const MAX_EVENT_COUNT: usize = 16;
56
57pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
58 max_bytes: 512,
59 min_bytes: 128,
60 target_before_cursor_over_total_bytes: 0.5,
61};
62
63pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
64
65pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
66 excerpt: DEFAULT_EXCERPT_OPTIONS,
67};
68
69pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
70 EditPredictionContextOptions {
71 use_imports: true,
72 max_retrieved_declarations: 0,
73 excerpt: DEFAULT_EXCERPT_OPTIONS,
74 score: EditPredictionScoreOptions {
75 omit_excerpt_overlaps: true,
76 },
77 };
78
79pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
80 context: DEFAULT_CONTEXT_OPTIONS,
81 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
82 max_diagnostic_bytes: 2048,
83 prompt_format: PromptFormat::DEFAULT,
84 file_indexing_parallelism: 1,
85};
86
/// Feature flag named `zeta2`; presumably gates this edit-prediction implementation —
/// TODO confirm at the call sites.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Returns `false`, so staff accounts are not opted in through this hook
    // (NOTE(review): exact effect depends on how `FeatureFlag` consumes it).
    fn enabled_for_staff() -> bool {
        false
    }
}
96
/// App-global holder for the shared [`Zeta`] entity; set by `Zeta::global` and read by
/// `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
101
/// Edit-prediction engine: tracks per-project edit history and retrieved context,
/// builds prediction requests, and talks to the prediction API endpoints.
pub struct Zeta {
    /// Client whose HTTP client is used for every API request.
    client: Arc<Client>,
    /// Receives edit-prediction usage reported in API response headers.
    user_store: Entity<UserStore>,
    /// Bearer token sent with API requests; refreshed when the token listener fires
    /// or when the server reports the token as expired.
    llm_token: LlmApiToken,
    /// Subscription to `RefreshLlmTokenListener` set up in `new`; held for `Zeta`'s lifetime.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options used for context gathering and request construction.
    options: ZetaOptions,
    /// Set once the server reports a minimum required version newer than this build.
    update_required: bool,
    /// When set via `debug_info`, debug events are sent on this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
112
/// Tunable options controlling context gathering and prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for a request is gathered.
    pub context: ContextMode,
    /// Sent as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format placed in each request (also used when building the local debug prompt).
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project's state is first created.
    pub file_indexing_parallelism: usize,
}
121
122#[derive(Debug, Clone, PartialEq)]
123pub enum ContextMode {
124 Llm(LlmContextOptions),
125 Syntax(EditPredictionContextOptions),
126}
127
128impl ContextMode {
129 pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
130 match self {
131 ContextMode::Llm(options) => &options.excerpt,
132 ContextMode::Syntax(options) => &options.excerpt,
133 }
134 }
135}
136
/// Events sent on the channel returned by `Zeta::debug_info`.
pub enum ZetaDebugInfo {
    /// A context refresh task started running (see `refresh_context`).
    ContextRetrievalStarted(ZetaContextRetrievalDebugInfo),
    /// Presumably emitted by `find_related_excerpts` (which receives the debug sender)
    /// once search queries are generated — TODO confirm.
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    /// Presumably emitted by `find_related_excerpts` after the searches run — TODO confirm.
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    /// Presumably emitted by `find_related_excerpts` after results are filtered — TODO confirm.
    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
    /// A context refresh finished and the project's context was replaced.
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// A prediction request was built; sent before the request goes out.
    EditPredicted(ZetaEditPredictionDebugInfo),
}

/// Payload for context-retrieval debug events.
pub struct ZetaContextRetrievalDebugInfo {
    /// Project whose context is being retrieved.
    pub project: Entity<Project>,
    /// When the event was emitted.
    pub timestamp: Instant,
}

/// Payload for `ZetaDebugInfo::EditPredicted`.
pub struct ZetaEditPredictionDebugInfo {
    /// The request as built, before it is sent.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics and building the request.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally with `build_prompt`, or its error message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response, or an error message (including
    /// "Request skipped" when `ZED_ZETA2_SKIP_REQUEST` short-circuits the request).
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
pub struct ZetaSearchQueryDebugInfo {
    /// Project the queries were generated for.
    pub project: Entity<Project>,
    /// When the event was emitted.
    pub timestamp: Instant,
    /// The generated search queries.
    pub queries: Vec<SearchToolQuery>,
}

/// Re-export of the API's `predict_edits_v3::DebugInfo` type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
167
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Syntax index for the project, created when this state is first initialized.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; size is bounded by `MAX_EVENT_COUNT` in `push_event`.
    events: VecDeque<Event>,
    /// Buffers tracked for changes, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Related excerpts from the latest completed context refresh; `None` until one completes.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    /// In-flight context refresh; the task clears this slot itself when it finishes.
    refresh_context_task: Option<Task<Option<()>>>,
    /// Pending (possibly debounced) refresh, replaced on each `refresh_context_if_needed` call.
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    /// When `refresh_context_if_needed` last ran; used to detect idle periods.
    refresh_context_timestamp: Option<Instant>,
}
178
179#[derive(Debug, Clone)]
180struct CurrentEditPrediction {
181 pub requested_by_buffer_id: EntityId,
182 pub prediction: EditPrediction,
183}
184
185impl CurrentEditPrediction {
186 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
187 let Some(new_edits) = self
188 .prediction
189 .interpolate(&self.prediction.buffer.read(cx))
190 else {
191 return false;
192 };
193
194 if self.prediction.buffer != old_prediction.prediction.buffer {
195 return true;
196 }
197
198 let Some(old_edits) = old_prediction
199 .prediction
200 .interpolate(&old_prediction.prediction.buffer.read(cx))
201 else {
202 return true;
203 };
204
205 // This reduces the occurrence of UI thrash from replacing edits
206 //
207 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
208 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
209 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
210 && old_edits.len() == 1
211 && new_edits.len() == 1
212 {
213 let (old_range, old_text) = &old_edits[0];
214 let (new_range, new_text) = &new_edits[0];
215 new_range == old_range && new_text.starts_with(old_text)
216 } else {
217 true
218 }
219 }
220}
221
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction's edits target this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}

/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; new snapshots are compared against it.
    snapshot: BufferSnapshot,
    /// Edit-event subscription and release observer created at registration.
    _subscriptions: [gpui::Subscription; 2],
}
233
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. `push_event` folds a
    /// follow-up change to the same buffer into the latest entry when it arrives within
    /// `BUFFER_CHANGE_GROUPING_INTERVAL` and continues from that entry's `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent change folded into this entry was recorded.
        timestamp: Instant,
    },
}
242
243impl Event {
244 pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
245 match self {
246 Event::BufferChange {
247 old_snapshot,
248 new_snapshot,
249 ..
250 } => {
251 let path = new_snapshot.file().map(|f| f.full_path(cx));
252
253 let old_path = old_snapshot.file().and_then(|f| {
254 let old_path = f.full_path(cx);
255 if Some(&old_path) != path.as_ref() {
256 Some(old_path)
257 } else {
258 None
259 }
260 });
261
262 // TODO [zeta2] move to bg?
263 let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
264
265 if path == old_path && diff.is_empty() {
266 None
267 } else {
268 Some(predict_edits_v3::Event::BufferChange {
269 old_path,
270 path,
271 diff,
272 //todo: Actually detect if this edit was predicted or not
273 predicted: false,
274 })
275 }
276 }
277 }
278 }
279}
280
281impl Zeta {
282 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
283 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
284 }
285
286 pub fn global(
287 client: &Arc<Client>,
288 user_store: &Entity<UserStore>,
289 cx: &mut App,
290 ) -> Entity<Self> {
291 cx.try_global::<ZetaGlobal>()
292 .map(|global| global.0.clone())
293 .unwrap_or_else(|| {
294 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
295 cx.set_global(ZetaGlobal(zeta.clone()));
296 zeta
297 })
298 }
299
300 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
301 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
302
303 Self {
304 projects: HashMap::default(),
305 client,
306 user_store,
307 options: DEFAULT_OPTIONS,
308 llm_token: LlmApiToken::default(),
309 _llm_token_subscription: cx.subscribe(
310 &refresh_llm_token_listener,
311 |this, _listener, _event, cx| {
312 let client = this.client.clone();
313 let llm_token = this.llm_token.clone();
314 cx.spawn(async move |_this, _cx| {
315 llm_token.refresh(&client).await?;
316 anyhow::Ok(())
317 })
318 .detach_and_log_err(cx);
319 },
320 ),
321 update_required: false,
322 debug_tx: None,
323 }
324 }
325
326 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
327 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
328 self.debug_tx = Some(debug_watch_tx);
329 debug_watch_rx
330 }
331
332 pub fn options(&self) -> &ZetaOptions {
333 &self.options
334 }
335
336 pub fn set_options(&mut self, options: ZetaOptions) {
337 self.options = options;
338 }
339
340 pub fn clear_history(&mut self) {
341 for zeta_project in self.projects.values_mut() {
342 zeta_project.events.clear();
343 }
344 }
345
346 pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
347 self.projects
348 .get(&project.entity_id())
349 .map(|project| project.events.iter())
350 .into_iter()
351 .flatten()
352 }
353
354 pub fn context_for_project(
355 &self,
356 project: &Entity<Project>,
357 ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
358 self.projects
359 .get(&project.entity_id())
360 .and_then(|project| {
361 Some(
362 project
363 .context
364 .as_ref()?
365 .iter()
366 .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
367 )
368 })
369 .into_iter()
370 .flatten()
371 }
372
373 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
374 self.user_store.read(cx).edit_prediction_usage()
375 }
376
377 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
378 self.get_or_init_zeta_project(project, cx);
379 }
380
381 pub fn register_buffer(
382 &mut self,
383 buffer: &Entity<Buffer>,
384 project: &Entity<Project>,
385 cx: &mut Context<Self>,
386 ) {
387 let zeta_project = self.get_or_init_zeta_project(project, cx);
388 Self::register_buffer_impl(zeta_project, buffer, project, cx);
389 }
390
391 fn get_or_init_zeta_project(
392 &mut self,
393 project: &Entity<Project>,
394 cx: &mut App,
395 ) -> &mut ZetaProject {
396 self.projects
397 .entry(project.entity_id())
398 .or_insert_with(|| ZetaProject {
399 syntax_index: cx.new(|cx| {
400 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
401 }),
402 events: VecDeque::new(),
403 registered_buffers: HashMap::default(),
404 current_prediction: None,
405 context: None,
406 refresh_context_task: None,
407 refresh_context_debounce_task: None,
408 refresh_context_timestamp: None,
409 })
410 }
411
412 fn register_buffer_impl<'a>(
413 zeta_project: &'a mut ZetaProject,
414 buffer: &Entity<Buffer>,
415 project: &Entity<Project>,
416 cx: &mut Context<Self>,
417 ) -> &'a mut RegisteredBuffer {
418 let buffer_id = buffer.entity_id();
419 match zeta_project.registered_buffers.entry(buffer_id) {
420 hash_map::Entry::Occupied(entry) => entry.into_mut(),
421 hash_map::Entry::Vacant(entry) => {
422 let snapshot = buffer.read(cx).snapshot();
423 let project_entity_id = project.entity_id();
424 entry.insert(RegisteredBuffer {
425 snapshot,
426 _subscriptions: [
427 cx.subscribe(buffer, {
428 let project = project.downgrade();
429 move |this, buffer, event, cx| {
430 if let language::BufferEvent::Edited = event
431 && let Some(project) = project.upgrade()
432 {
433 this.report_changes_for_buffer(&buffer, &project, cx);
434 }
435 }
436 }),
437 cx.observe_release(buffer, move |this, _buffer, _cx| {
438 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
439 else {
440 return;
441 };
442 zeta_project.registered_buffers.remove(&buffer_id);
443 }),
444 ],
445 })
446 }
447 }
448 }
449
450 fn report_changes_for_buffer(
451 &mut self,
452 buffer: &Entity<Buffer>,
453 project: &Entity<Project>,
454 cx: &mut Context<Self>,
455 ) -> BufferSnapshot {
456 let zeta_project = self.get_or_init_zeta_project(project, cx);
457 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
458
459 let new_snapshot = buffer.read(cx).snapshot();
460 if new_snapshot.version != registered_buffer.snapshot.version {
461 let old_snapshot =
462 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
463 Self::push_event(
464 zeta_project,
465 Event::BufferChange {
466 old_snapshot,
467 new_snapshot: new_snapshot.clone(),
468 timestamp: Instant::now(),
469 },
470 );
471 }
472
473 new_snapshot
474 }
475
476 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
477 let events = &mut zeta_project.events;
478
479 if let Some(Event::BufferChange {
480 new_snapshot: last_new_snapshot,
481 timestamp: last_timestamp,
482 ..
483 }) = events.back_mut()
484 {
485 // Coalesce edits for the same buffer when they happen one after the other.
486 let Event::BufferChange {
487 old_snapshot,
488 new_snapshot,
489 timestamp,
490 } = &event;
491
492 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
493 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
494 && old_snapshot.version == last_new_snapshot.version
495 {
496 *last_new_snapshot = new_snapshot.clone();
497 *last_timestamp = *timestamp;
498 return;
499 }
500 }
501
502 if events.len() >= MAX_EVENT_COUNT {
503 // These are halved instead of popping to improve prompt caching.
504 events.drain(..MAX_EVENT_COUNT / 2);
505 }
506
507 events.push_back(event);
508 }
509
510 fn current_prediction_for_buffer(
511 &self,
512 buffer: &Entity<Buffer>,
513 project: &Entity<Project>,
514 cx: &App,
515 ) -> Option<BufferEditPrediction<'_>> {
516 let project_state = self.projects.get(&project.entity_id())?;
517
518 let CurrentEditPrediction {
519 requested_by_buffer_id,
520 prediction,
521 } = project_state.current_prediction.as_ref()?;
522
523 if prediction.targets_buffer(buffer.read(cx), cx) {
524 Some(BufferEditPrediction::Local { prediction })
525 } else if *requested_by_buffer_id == buffer.entity_id() {
526 Some(BufferEditPrediction::Jump { prediction })
527 } else {
528 None
529 }
530 }
531
532 fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
533 let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
534 return;
535 };
536
537 let Some(prediction) = project_state.current_prediction.take() else {
538 return;
539 };
540 let request_id = prediction.prediction.id.into();
541
542 let client = self.client.clone();
543 let llm_token = self.llm_token.clone();
544 let app_version = AppVersion::global(cx);
545 cx.spawn(async move |this, cx| {
546 let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
547 http_client::Url::parse(&predict_edits_url)?
548 } else {
549 client
550 .http_client()
551 .build_zed_llm_url("/predict_edits/accept", &[])?
552 };
553
554 let response = cx
555 .background_spawn(Self::send_api_request::<()>(
556 move |builder| {
557 let req = builder.uri(url.as_ref()).body(
558 serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
559 );
560 Ok(req?)
561 },
562 client,
563 llm_token,
564 app_version,
565 ))
566 .await;
567
568 Self::handle_api_response(&this, response, cx)?;
569 anyhow::Ok(())
570 })
571 .detach_and_log_err(cx);
572 }
573
574 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
575 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
576 project_state.current_prediction.take();
577 };
578 }
579
580 pub fn refresh_prediction(
581 &mut self,
582 project: &Entity<Project>,
583 buffer: &Entity<Buffer>,
584 position: language::Anchor,
585 cx: &mut Context<Self>,
586 ) -> Task<Result<()>> {
587 let request_task = self.request_prediction(project, buffer, position, cx);
588 let buffer = buffer.clone();
589 let project = project.clone();
590
591 cx.spawn(async move |this, cx| {
592 if let Some(prediction) = request_task.await? {
593 this.update(cx, |this, cx| {
594 let project_state = this
595 .projects
596 .get_mut(&project.entity_id())
597 .context("Project not found")?;
598
599 let new_prediction = CurrentEditPrediction {
600 requested_by_buffer_id: buffer.entity_id(),
601 prediction: prediction,
602 };
603
604 if project_state
605 .current_prediction
606 .as_ref()
607 .is_none_or(|old_prediction| {
608 new_prediction.should_replace_prediction(&old_prediction, cx)
609 })
610 {
611 project_state.current_prediction = Some(new_prediction);
612 }
613 anyhow::Ok(())
614 })??;
615 }
616 Ok(())
617 })
618 }
619
    /// Builds and sends a prediction request for `position` in `buffer`.
    ///
    /// Resolves to `None` in these cases:
    /// - no excerpt/context could be selected around the cursor;
    /// - `EditPrediction::from_response` yields no prediction.
    ///
    /// Request construction depends on `self.options.context`:
    /// - `ContextMode::Llm`: the related excerpts stored in the project's `context`
    ///   (from a previous `refresh_context`) plus an excerpt around the cursor.
    /// - `ContextMode::Syntax`: context gathered via `EditPredictionContext::gather_context`.
    ///
    /// Fails in these cases:
    /// - the buffer has no file;
    /// - the API request fails;
    /// - in debug builds, `ZED_ZETA2_SKIP_REQUEST` is set.
    ///
    /// Does not touch `current_prediction`; see `refresh_prediction`.
    fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the index state handle here; it is locked later on the background task.
        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Edit history is converted here because `to_request_event` needs `&App`.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        // Absolute path of the directory containing the buffer's file, if it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        // Related excerpts from the last context refresh, as (snapshot, full path, ranges).
        // Buffers without a file are skipped.
        let mut included_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&HashMap::default())
            .iter()
            .filter_map(|(buffer, ranges)| {
                let buffer = buffer.read(cx);
                Some((
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                // Start of the timing window reported as `retrieval_time` in debug info.
                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = match options.context {
                    ContextMode::Llm(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
                            ..snapshot.anchor_before(excerpt.range.end);

                        // Add the cursor excerpt to this buffer's ranges. If the buffer is
                        // already included, insert in sorted position (start ascending,
                        // then end descending) and move that file to the end of the list.
                        // NOTE(review): presumably so the cursor's file comes last in the
                        // prompt — confirm. Otherwise append a new entry for the buffer.
                        if let Some(buffer_ix) = included_files
                            .iter()
                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
                        {
                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
                            let range_ix = ranges
                                .binary_search_by(|probe| {
                                    probe
                                        .start
                                        .cmp(&excerpt_anchor_range.start, buffer)
                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
                                })
                                .unwrap_or_else(|ix| ix);

                            ranges.insert(range_ix, excerpt_anchor_range);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                snapshot,
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert each file's anchor ranges into row ranges and merge them.
                        let included_files = included_files
                            .into_iter()
                            .map(|(buffer, path, ranges)| {
                                let excerpts = merge_excerpts(
                                    &buffer,
                                    ranges.iter().map(|range| {
                                        let point_range = range.to_point(&buffer);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    }),
                                );
                                predict_edits_v3::IncludedFile {
                                    path,
                                    max_row: Line(buffer.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        // The cursor excerpt travels inside `included_files`, so the
                        // dedicated excerpt fields are left empty in this mode.
                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // With a debug listener attached, report the request plus a locally rendered
                // prompt, and keep a sender to forward the eventual response.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    let local_prompt = build_prompt(&request)
                        .map(|(prompt, _)| prompt)
                        .map_err(|err| err.to_string());

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
                            request: request.clone(),
                            retrieval_time,
                            buffer: buffer.downgrade(),
                            local_prompt,
                            position,
                            response_rx,
                        }))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch: build the request but never send it.
                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send(Err("Request skipped".to_string()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let response =
                    Self::send_prediction_request(client, llm_token, app_version, request).await;

                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                        )
                        .ok();
                }

                response.map(|(res, usage)| (Some(res), usage))
            }
        });

        let buffer = buffer.clone();

        cx.spawn({
            let project = project.clone();
            async move |this, cx| {
                // Records usage and update-required state, then unwraps the response.
                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
            }
        })
    }
884
885 async fn send_prediction_request(
886 client: Arc<Client>,
887 llm_token: LlmApiToken,
888 app_version: SemanticVersion,
889 request: predict_edits_v3::PredictEditsRequest,
890 ) -> Result<(
891 predict_edits_v3::PredictEditsResponse,
892 Option<EditPredictionUsage>,
893 )> {
894 let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
895 http_client::Url::parse(&predict_edits_url)?
896 } else {
897 client
898 .http_client()
899 .build_zed_llm_url("/predict_edits/v3", &[])?
900 };
901
902 Self::send_api_request(
903 |builder| {
904 let req = builder
905 .uri(url.as_ref())
906 .body(serde_json::to_string(&request)?.into());
907 Ok(req?)
908 },
909 client,
910 llm_token,
911 app_version,
912 )
913 .await
914 }
915
916 fn handle_api_response<T>(
917 this: &WeakEntity<Self>,
918 response: Result<(T, Option<EditPredictionUsage>)>,
919 cx: &mut gpui::AsyncApp,
920 ) -> Result<T> {
921 match response {
922 Ok((data, usage)) => {
923 if let Some(usage) = usage {
924 this.update(cx, |this, cx| {
925 this.user_store.update(cx, |user_store, cx| {
926 user_store.update_edit_prediction_usage(usage, cx);
927 });
928 })
929 .ok();
930 }
931 Ok(data)
932 }
933 Err(err) => {
934 if err.is::<ZedUpdateRequiredError>() {
935 cx.update(|cx| {
936 this.update(cx, |this, _cx| {
937 this.update_required = true;
938 })
939 .ok();
940
941 let error_message: SharedString = err.to_string().into();
942 show_app_notification(
943 NotificationId::unique::<ZedUpdateRequiredError>(),
944 cx,
945 move |cx| {
946 cx.new(|cx| {
947 ErrorMessagePrompt::new(error_message.clone(), cx)
948 .with_link_button("Update Zed", "https://zed.dev/releases")
949 })
950 },
951 );
952 })
953 .ok();
954 }
955 Err(err)
956 }
957 }
958 }
959
960 async fn send_api_request<Res>(
961 build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
962 client: Arc<Client>,
963 llm_token: LlmApiToken,
964 app_version: SemanticVersion,
965 ) -> Result<(Res, Option<EditPredictionUsage>)>
966 where
967 Res: DeserializeOwned,
968 {
969 let http_client = client.http_client();
970 let mut token = llm_token.acquire(&client).await?;
971 let mut did_retry = false;
972
973 loop {
974 let request_builder = http_client::Request::builder().method(Method::POST);
975
976 let request = build(
977 request_builder
978 .header("Content-Type", "application/json")
979 .header("Authorization", format!("Bearer {}", token))
980 .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
981 )?;
982
983 let mut response = http_client.send(request).await?;
984
985 if let Some(minimum_required_version) = response
986 .headers()
987 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
988 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
989 {
990 anyhow::ensure!(
991 app_version >= minimum_required_version,
992 ZedUpdateRequiredError {
993 minimum_version: minimum_required_version
994 }
995 );
996 }
997
998 if response.status().is_success() {
999 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1000
1001 let mut body = Vec::new();
1002 response.body_mut().read_to_end(&mut body).await?;
1003 return Ok((serde_json::from_slice(&body)?, usage));
1004 } else if !did_retry
1005 && response
1006 .headers()
1007 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1008 .is_some()
1009 {
1010 did_retry = true;
1011 token = llm_token.refresh(&client).await?;
1012 } else {
1013 let mut body = String::new();
1014 response.body_mut().read_to_string(&mut body).await?;
1015 anyhow::bail!(
1016 "Request failed with status: {:?}\nBody: {}",
1017 response.status(),
1018 body
1019 );
1020 }
1021 }
1022 }
1023
1024 pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1025 pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1026
1027 // Refresh the related excerpts when the user just beguns editing after
1028 // an idle period, and after they pause editing.
1029 fn refresh_context_if_needed(
1030 &mut self,
1031 project: &Entity<Project>,
1032 buffer: &Entity<language::Buffer>,
1033 cursor_position: language::Anchor,
1034 cx: &mut Context<Self>,
1035 ) {
1036 if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1037 return;
1038 }
1039
1040 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1041 return;
1042 };
1043
1044 let now = Instant::now();
1045 let was_idle = zeta_project
1046 .refresh_context_timestamp
1047 .map_or(true, |timestamp| {
1048 now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1049 });
1050 zeta_project.refresh_context_timestamp = Some(now);
1051 zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1052 let buffer = buffer.clone();
1053 let project = project.clone();
1054 async move |this, cx| {
1055 if was_idle {
1056 log::debug!("refetching edit prediction context after idle");
1057 } else {
1058 cx.background_executor()
1059 .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1060 .await;
1061 log::debug!("refetching edit prediction context after pause");
1062 }
1063 this.update(cx, |this, cx| {
1064 this.refresh_context(project, buffer, cursor_position, cx);
1065 })
1066 .ok()
1067 }
1068 }));
1069 }
1070
1071 // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1072 // and avoid spawning more than one concurrent task.
1073 fn refresh_context(
1074 &mut self,
1075 project: Entity<Project>,
1076 buffer: Entity<language::Buffer>,
1077 cursor_position: language::Anchor,
1078 cx: &mut Context<Self>,
1079 ) {
1080 let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1081 return;
1082 };
1083
1084 let debug_tx = self.debug_tx.clone();
1085
1086 zeta_project
1087 .refresh_context_task
1088 .get_or_insert(cx.spawn(async move |this, cx| {
1089 if let Some(debug_tx) = &debug_tx {
1090 debug_tx
1091 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
1092 ZetaContextRetrievalDebugInfo {
1093 project: project.clone(),
1094 timestamp: Instant::now(),
1095 },
1096 ))
1097 .ok();
1098 }
1099
1100 let related_excerpts = this
1101 .update(cx, |this, cx| {
1102 let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1103 return Task::ready(anyhow::Ok(HashMap::default()));
1104 };
1105
1106 let ContextMode::Llm(options) = &this.options().context else {
1107 return Task::ready(anyhow::Ok(HashMap::default()));
1108 };
1109
1110 find_related_excerpts(
1111 buffer.clone(),
1112 cursor_position,
1113 &project,
1114 zeta_project.events.iter(),
1115 options,
1116 debug_tx,
1117 cx,
1118 )
1119 })
1120 .ok()?
1121 .await
1122 .log_err()
1123 .unwrap_or_default();
1124 this.update(cx, |this, _cx| {
1125 let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1126 return;
1127 };
1128 zeta_project.context = Some(related_excerpts);
1129 zeta_project.refresh_context_task.take();
1130 if let Some(debug_tx) = &this.debug_tx {
1131 debug_tx
1132 .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1133 ZetaContextRetrievalDebugInfo {
1134 project,
1135 timestamp: Instant::now(),
1136 },
1137 ))
1138 .ok();
1139 }
1140 })
1141 .ok()
1142 }));
1143 }
1144
1145 fn gather_nearby_diagnostics(
1146 cursor_offset: usize,
1147 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1148 snapshot: &BufferSnapshot,
1149 max_diagnostics_bytes: usize,
1150 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1151 // TODO: Could make this more efficient
1152 let mut diagnostic_groups = Vec::new();
1153 for (language_server_id, diagnostics) in diagnostic_sets {
1154 let mut groups = Vec::new();
1155 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1156 diagnostic_groups.extend(
1157 groups
1158 .into_iter()
1159 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1160 );
1161 }
1162
1163 // sort by proximity to cursor
1164 diagnostic_groups.sort_by_key(|group| {
1165 let range = &group.entries[group.primary_ix].range;
1166 if range.start >= cursor_offset {
1167 range.start - cursor_offset
1168 } else if cursor_offset >= range.end {
1169 cursor_offset - range.end
1170 } else {
1171 (cursor_offset - range.start).min(range.end - cursor_offset)
1172 }
1173 });
1174
1175 let mut results = Vec::new();
1176 let mut diagnostic_groups_truncated = false;
1177 let mut diagnostics_byte_count = 0;
1178 for group in diagnostic_groups {
1179 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1180 diagnostics_byte_count += raw_value.get().len();
1181 if diagnostics_byte_count > max_diagnostics_bytes {
1182 diagnostic_groups_truncated = true;
1183 break;
1184 }
1185 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1186 }
1187
1188 (results, diagnostic_groups_truncated)
1189 }
1190
1191 // TODO: Dedupe with similar code in request_prediction?
1192 pub fn cloud_request_for_zeta_cli(
1193 &mut self,
1194 project: &Entity<Project>,
1195 buffer: &Entity<Buffer>,
1196 position: language::Anchor,
1197 cx: &mut Context<Self>,
1198 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1199 let project_state = self.projects.get(&project.entity_id());
1200
1201 let index_state = project_state.map(|state| {
1202 state
1203 .syntax_index
1204 .read_with(cx, |index, _cx| index.state().clone())
1205 });
1206 let options = self.options.clone();
1207 let snapshot = buffer.read(cx).snapshot();
1208 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1209 return Task::ready(Err(anyhow!("No file path for excerpt")));
1210 };
1211 let worktree_snapshots = project
1212 .read(cx)
1213 .worktrees(cx)
1214 .map(|worktree| worktree.read(cx).snapshot())
1215 .collect::<Vec<_>>();
1216
1217 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1218 let mut path = f.worktree.read(cx).absolutize(&f.path);
1219 if path.pop() { Some(path) } else { None }
1220 });
1221
1222 cx.background_spawn(async move {
1223 let index_state = if let Some(index_state) = index_state {
1224 Some(index_state.lock_owned().await)
1225 } else {
1226 None
1227 };
1228
1229 let cursor_point = position.to_point(&snapshot);
1230
1231 let debug_info = true;
1232 EditPredictionContext::gather_context(
1233 cursor_point,
1234 &snapshot,
1235 parent_abs_path.as_deref(),
1236 match &options.context {
1237 ContextMode::Llm(_) => {
1238 // TODO
1239 panic!("Llm mode not supported in zeta cli yet");
1240 }
1241 ContextMode::Syntax(edit_prediction_context_options) => {
1242 edit_prediction_context_options
1243 }
1244 },
1245 index_state.as_deref(),
1246 )
1247 .context("Failed to select excerpt")
1248 .map(|context| {
1249 make_syntax_context_cloud_request(
1250 excerpt_path.into(),
1251 context,
1252 // TODO pass everything
1253 Vec::new(),
1254 false,
1255 Vec::new(),
1256 false,
1257 None,
1258 debug_info,
1259 &worktree_snapshots,
1260 index_state.as_deref(),
1261 Some(options.max_prompt_bytes),
1262 options.prompt_format,
1263 )
1264 })
1265 })
1266 }
1267
1268 pub fn wait_for_initial_indexing(
1269 &mut self,
1270 project: &Entity<Project>,
1271 cx: &mut App,
1272 ) -> Task<Result<()>> {
1273 let zeta_project = self.get_or_init_zeta_project(project, cx);
1274 zeta_project
1275 .syntax_index
1276 .read(cx)
1277 .wait_for_initial_file_indexing(cx)
1278 }
1279}
1280
1281#[derive(Error, Debug)]
1282#[error(
1283 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1284)]
1285pub struct ZedUpdateRequiredError {
1286 minimum_version: SemanticVersion,
1287}
1288
1289fn make_syntax_context_cloud_request(
1290 excerpt_path: Arc<Path>,
1291 context: EditPredictionContext,
1292 events: Vec<predict_edits_v3::Event>,
1293 can_collect_data: bool,
1294 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1295 diagnostic_groups_truncated: bool,
1296 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1297 debug_info: bool,
1298 worktrees: &Vec<worktree::Snapshot>,
1299 index_state: Option<&SyntaxIndexState>,
1300 prompt_max_bytes: Option<usize>,
1301 prompt_format: PromptFormat,
1302) -> predict_edits_v3::PredictEditsRequest {
1303 let mut signatures = Vec::new();
1304 let mut declaration_to_signature_index = HashMap::default();
1305 let mut referenced_declarations = Vec::new();
1306
1307 for snippet in context.declarations {
1308 let project_entry_id = snippet.declaration.project_entry_id();
1309 let Some(path) = worktrees.iter().find_map(|worktree| {
1310 worktree.entry_for_id(project_entry_id).map(|entry| {
1311 let mut full_path = RelPathBuf::new();
1312 full_path.push(worktree.root_name());
1313 full_path.push(&entry.path);
1314 full_path
1315 })
1316 }) else {
1317 continue;
1318 };
1319
1320 let parent_index = index_state.and_then(|index_state| {
1321 snippet.declaration.parent().and_then(|parent| {
1322 add_signature(
1323 parent,
1324 &mut declaration_to_signature_index,
1325 &mut signatures,
1326 index_state,
1327 )
1328 })
1329 });
1330
1331 let (text, text_is_truncated) = snippet.declaration.item_text();
1332 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1333 path: path.as_std_path().into(),
1334 text: text.into(),
1335 range: snippet.declaration.item_line_range(),
1336 text_is_truncated,
1337 signature_range: snippet.declaration.signature_range_in_item_text(),
1338 parent_index,
1339 signature_score: snippet.score(DeclarationStyle::Signature),
1340 declaration_score: snippet.score(DeclarationStyle::Declaration),
1341 score_components: snippet.components,
1342 });
1343 }
1344
1345 let excerpt_parent = index_state.and_then(|index_state| {
1346 context
1347 .excerpt
1348 .parent_declarations
1349 .last()
1350 .and_then(|(parent, _)| {
1351 add_signature(
1352 *parent,
1353 &mut declaration_to_signature_index,
1354 &mut signatures,
1355 index_state,
1356 )
1357 })
1358 });
1359
1360 predict_edits_v3::PredictEditsRequest {
1361 excerpt_path,
1362 excerpt: context.excerpt_text.body,
1363 excerpt_line_range: context.excerpt.line_range,
1364 excerpt_range: context.excerpt.range,
1365 cursor_point: predict_edits_v3::Point {
1366 line: predict_edits_v3::Line(context.cursor_point.row),
1367 column: context.cursor_point.column,
1368 },
1369 referenced_declarations,
1370 included_files: vec![],
1371 signatures,
1372 excerpt_parent,
1373 events,
1374 can_collect_data,
1375 diagnostic_groups,
1376 diagnostic_groups_truncated,
1377 git_info,
1378 debug_info,
1379 prompt_max_bytes,
1380 prompt_format,
1381 }
1382}
1383
1384fn add_signature(
1385 declaration_id: DeclarationId,
1386 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1387 signatures: &mut Vec<Signature>,
1388 index: &SyntaxIndexState,
1389) -> Option<usize> {
1390 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1391 return Some(*signature_index);
1392 }
1393 let Some(parent_declaration) = index.declaration(declaration_id) else {
1394 log::error!("bug: missing parent declaration");
1395 return None;
1396 };
1397 let parent_index = parent_declaration.parent().and_then(|parent| {
1398 add_signature(parent, declaration_to_signature_index, signatures, index)
1399 });
1400 let (text, text_is_truncated) = parent_declaration.signature_text();
1401 let signature_index = signatures.len();
1402 signatures.push(Signature {
1403 text: text.into(),
1404 text_is_truncated,
1405 parent_index,
1406 range: parent_declaration.signature_line_range(),
1407 });
1408 declaration_to_signature_index.insert(declaration_id, signature_index);
1409 Some(signature_index)
1410}
1411
// Tests for `Zeta`'s prediction request flow.
//
// `init_test` installs a fake HTTP client. Every request to `/predict_edits/v3` is decoded and
// forwarded to the test through an mpsc channel, together with a oneshot sender. The HTTP call
// completes only once the test sends a `PredictEditsResponse` on that sender. This lets each
// test inspect the outgoing request and then script the server's reply.
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // A prediction that edits the buffer it was requested for is reported as `Local`. A
    // prediction that edits a different file is reported as a `Jump` to that file, and is
    // reported as `Local` once queried for that file's buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Once the target file is opened, the same prediction is local to its buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Checks the request's excerpt path and cursor point. Also checks that a response replacing
    // the whole file is reduced to a single edit inserting " are you?" at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Edits to a registered buffer are sent in the request as a single `BufferChange` event
    // containing a unified diff. Here the edit inserts "How" on the empty second line.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the (empty) second line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        // Context lines of the unified diff keep their leading space after `indoc!` dedents.
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                    @@ -1,3 +1,3 @@
                     Hello!
                    -
                    +How
                     Bye
                "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // LSP diagnostics published for a file (before its buffer is opened) are included in the
    // request as serialized diagnostic groups.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // Bound to named `_`-prefixed variables (not `_`) so the task and the response sender
        // stay alive until the end of the test.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        // NOTE(review): the LSP range (1,1)..(1,5) appears to be clamped to the end of line 1
        // ("Bye"), which is why the resolved offsets are 8..10.
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Sets up test globals and returns a `Zeta` entity backed by a fake HTTP client, plus the
    // receiving end of the request channel.
    //
    // The fake client answers `/client/llm_tokens` with a test token. For `/predict_edits/v3`,
    // it forwards the decoded request and a oneshot sender to the returned receiver, then waits
    // for the test's response. If that sender is dropped, the `?` on `res_rx.await` makes the
    // fake request fail. Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}