1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
9use edit_prediction_context::{
10 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
11 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::{mpsc, oneshot};
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Maximum time between two consecutive edits of the same buffer for them to
/// be coalesced into a single [`Event::BufferChange`] by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When a project's event queue has reached this length, `Zeta::push_event`
/// drops the oldest half before appending the new event.
const MAX_EVENT_COUNT: usize = 16;
45
/// Default context-gathering options; used as the `context` field of
/// [`DEFAULT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    use_references: false,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
};
58
/// Options a freshly constructed [`Zeta`] starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
66
/// Newtype around the shared [`Zeta`] entity so it can be stored as a gpui
/// global (set in `Zeta::global`, read in `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
71
/// Edit prediction state: the client and credentials used to request
/// predictions from the cloud, plus per-project tracking state.
pub struct Zeta {
    /// Client used to send prediction requests and to acquire/refresh the LLM token.
    client: Arc<Client>,
    /// Store from which edit prediction usage is read and into which it is updated.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential on prediction requests.
    llm_token: LlmApiToken,
    /// Handle for the subscription to `RefreshLlmTokenListener` that refreshes `llm_token`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options read when building requests and when creating a project's syntax index.
    options: ZetaOptions,
    /// Set to `true` once a request fails with [`ZedUpdateRequiredError`].
    update_required: bool,
    /// When set (via `debug_info`), debug information about each request is sent here.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
82
/// Tunable options for [`Zeta`].
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format sent with the request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
91
/// Debug information emitted for each prediction request while a debug
/// receiver (see `Zeta::debug_info`) is installed.
pub struct PredictionDebugInfo {
    /// The context gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent gathering `context`.
    pub retrieval_time: TimeDelta,
    /// The buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// The cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt built locally from the request, or the error message if building it failed.
    pub local_prompt: Result<String, String>,
    /// Receives the debug info from the server response, or an error message.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}
100
/// Debug information returned by the server in a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
102
/// State tracked by [`Zeta`] for a single project.
struct ZetaProject {
    /// Syntax index whose state is used when gathering context and signatures.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The most recent prediction that has not been accepted or discarded, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
109
/// A prediction together with the id of the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer that was passed to `refresh_prediction`.
    pub requested_by_buffer_id: EntityId,
    /// The prediction itself; it may target a different buffer than the requester.
    pub prediction: EditPrediction,
}
115
116impl CurrentEditPrediction {
117 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
118 let Some(new_edits) = self
119 .prediction
120 .interpolate(&self.prediction.buffer.read(cx))
121 else {
122 return false;
123 };
124
125 if self.prediction.buffer != old_prediction.prediction.buffer {
126 return true;
127 }
128
129 let Some(old_edits) = old_prediction
130 .prediction
131 .interpolate(&old_prediction.prediction.buffer.read(cx))
132 else {
133 return true;
134 };
135
136 // This reduces the occurrence of UI thrash from replacing edits
137 //
138 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
139 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
140 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
141 && old_edits.len() == 1
142 && new_edits.len() == 1
143 {
144 let (old_range, old_text) = &old_edits[0];
145 let (new_range, new_text) = &new_edits[0];
146 new_range == old_range && new_text.starts_with(old_text)
147 } else {
148 true
149 }
150 }
151}
152
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}
159
/// Tracking state for a buffer whose edits are recorded as events.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; the "old" side of the next event.
    snapshot: BufferSnapshot,
    /// Subscription to the buffer's events and observer of its release.
    _subscriptions: [gpui::Subscription; 2],
}
164
/// An editing event recorded for inclusion in prediction requests.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer contents before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer contents after the change.
        new_snapshot: BufferSnapshot,
        /// When the (most recent coalesced) change was recorded.
        timestamp: Instant,
    },
}
173
174impl Zeta {
175 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
176 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
177 }
178
179 pub fn global(
180 client: &Arc<Client>,
181 user_store: &Entity<UserStore>,
182 cx: &mut App,
183 ) -> Entity<Self> {
184 cx.try_global::<ZetaGlobal>()
185 .map(|global| global.0.clone())
186 .unwrap_or_else(|| {
187 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
188 cx.set_global(ZetaGlobal(zeta.clone()));
189 zeta
190 })
191 }
192
193 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
194 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
195
196 Self {
197 projects: HashMap::new(),
198 client,
199 user_store,
200 options: DEFAULT_OPTIONS,
201 llm_token: LlmApiToken::default(),
202 _llm_token_subscription: cx.subscribe(
203 &refresh_llm_token_listener,
204 |this, _listener, _event, cx| {
205 let client = this.client.clone();
206 let llm_token = this.llm_token.clone();
207 cx.spawn(async move |_this, _cx| {
208 llm_token.refresh(&client).await?;
209 anyhow::Ok(())
210 })
211 .detach_and_log_err(cx);
212 },
213 ),
214 update_required: false,
215 debug_tx: None,
216 }
217 }
218
219 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
220 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
221 self.debug_tx = Some(debug_watch_tx);
222 debug_watch_rx
223 }
224
/// Returns the options currently in effect.
pub fn options(&self) -> &ZetaOptions {
    &self.options
}
228
/// Replaces the current options.
///
/// Only the stored options are updated here; syntax indexes that already
/// exist are not touched by this method.
pub fn set_options(&mut self, options: ZetaOptions) {
    self.options = options;
}
232
233 pub fn clear_history(&mut self) {
234 for zeta_project in self.projects.values_mut() {
235 zeta_project.events.clear();
236 }
237 }
238
/// Returns the edit prediction usage currently held by the user store, if any.
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
    self.user_store.read(cx).edit_prediction_usage()
}
242
/// Ensures per-project state (including a syntax index) exists for `project`.
/// Does nothing if the project is already registered.
pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
    self.get_or_init_zeta_project(project, cx);
}
246
/// Starts tracking edits to `buffer` as part of `project`, registering the
/// project first if needed. Does nothing if the buffer is already tracked.
pub fn register_buffer(
    &mut self,
    buffer: &Entity<Buffer>,
    project: &Entity<Project>,
    cx: &mut Context<Self>,
) {
    let zeta_project = self.get_or_init_zeta_project(project, cx);
    Self::register_buffer_impl(zeta_project, buffer, project, cx);
}
256
257 fn get_or_init_zeta_project(
258 &mut self,
259 project: &Entity<Project>,
260 cx: &mut App,
261 ) -> &mut ZetaProject {
262 self.projects
263 .entry(project.entity_id())
264 .or_insert_with(|| ZetaProject {
265 syntax_index: cx.new(|cx| {
266 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
267 }),
268 events: VecDeque::new(),
269 registered_buffers: HashMap::new(),
270 current_prediction: None,
271 })
272 }
273
274 fn register_buffer_impl<'a>(
275 zeta_project: &'a mut ZetaProject,
276 buffer: &Entity<Buffer>,
277 project: &Entity<Project>,
278 cx: &mut Context<Self>,
279 ) -> &'a mut RegisteredBuffer {
280 let buffer_id = buffer.entity_id();
281 match zeta_project.registered_buffers.entry(buffer_id) {
282 hash_map::Entry::Occupied(entry) => entry.into_mut(),
283 hash_map::Entry::Vacant(entry) => {
284 let snapshot = buffer.read(cx).snapshot();
285 let project_entity_id = project.entity_id();
286 entry.insert(RegisteredBuffer {
287 snapshot,
288 _subscriptions: [
289 cx.subscribe(buffer, {
290 let project = project.downgrade();
291 move |this, buffer, event, cx| {
292 if let language::BufferEvent::Edited = event
293 && let Some(project) = project.upgrade()
294 {
295 this.report_changes_for_buffer(&buffer, &project, cx);
296 }
297 }
298 }),
299 cx.observe_release(buffer, move |this, _buffer, _cx| {
300 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
301 else {
302 return;
303 };
304 zeta_project.registered_buffers.remove(&buffer_id);
305 }),
306 ],
307 })
308 }
309 }
310 }
311
312 fn report_changes_for_buffer(
313 &mut self,
314 buffer: &Entity<Buffer>,
315 project: &Entity<Project>,
316 cx: &mut Context<Self>,
317 ) -> BufferSnapshot {
318 let zeta_project = self.get_or_init_zeta_project(project, cx);
319 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
320
321 let new_snapshot = buffer.read(cx).snapshot();
322 if new_snapshot.version != registered_buffer.snapshot.version {
323 let old_snapshot =
324 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
325 Self::push_event(
326 zeta_project,
327 Event::BufferChange {
328 old_snapshot,
329 new_snapshot: new_snapshot.clone(),
330 timestamp: Instant::now(),
331 },
332 );
333 }
334
335 new_snapshot
336 }
337
338 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
339 let events = &mut zeta_project.events;
340
341 if let Some(Event::BufferChange {
342 new_snapshot: last_new_snapshot,
343 timestamp: last_timestamp,
344 ..
345 }) = events.back_mut()
346 {
347 // Coalesce edits for the same buffer when they happen one after the other.
348 let Event::BufferChange {
349 old_snapshot,
350 new_snapshot,
351 timestamp,
352 } = &event;
353
354 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
355 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
356 && old_snapshot.version == last_new_snapshot.version
357 {
358 *last_new_snapshot = new_snapshot.clone();
359 *last_timestamp = *timestamp;
360 return;
361 }
362 }
363
364 if events.len() >= MAX_EVENT_COUNT {
365 // These are halved instead of popping to improve prompt caching.
366 events.drain(..MAX_EVENT_COUNT / 2);
367 }
368
369 events.push_back(event);
370 }
371
372 fn current_prediction_for_buffer(
373 &self,
374 buffer: &Entity<Buffer>,
375 project: &Entity<Project>,
376 cx: &App,
377 ) -> Option<BufferEditPrediction<'_>> {
378 let project_state = self.projects.get(&project.entity_id())?;
379
380 let CurrentEditPrediction {
381 requested_by_buffer_id,
382 prediction,
383 } = project_state.current_prediction.as_ref()?;
384
385 if prediction.targets_buffer(buffer.read(cx), cx) {
386 Some(BufferEditPrediction::Local { prediction })
387 } else if *requested_by_buffer_id == buffer.entity_id() {
388 Some(BufferEditPrediction::Jump { prediction })
389 } else {
390 None
391 }
392 }
393
394 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
395 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
396 project_state.current_prediction.take();
397 };
398 // TODO report accepted
399 }
400
401 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
402 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
403 project_state.current_prediction.take();
404 };
405 }
406
407 pub fn refresh_prediction(
408 &mut self,
409 project: &Entity<Project>,
410 buffer: &Entity<Buffer>,
411 position: language::Anchor,
412 cx: &mut Context<Self>,
413 ) -> Task<Result<()>> {
414 let request_task = self.request_prediction(project, buffer, position, cx);
415 let buffer = buffer.clone();
416 let project = project.clone();
417
418 cx.spawn(async move |this, cx| {
419 if let Some(prediction) = request_task.await? {
420 this.update(cx, |this, cx| {
421 let project_state = this
422 .projects
423 .get_mut(&project.entity_id())
424 .context("Project not found")?;
425
426 let new_prediction = CurrentEditPrediction {
427 requested_by_buffer_id: buffer.entity_id(),
428 prediction: prediction,
429 };
430
431 if project_state
432 .current_prediction
433 .as_ref()
434 .is_none_or(|old_prediction| {
435 new_prediction.should_replace_prediction(&old_prediction, cx)
436 })
437 {
438 project_state.current_prediction = Some(new_prediction);
439 }
440 anyhow::Ok(())
441 })??;
442 }
443 Ok(())
444 })
445 }
446
/// Builds and sends a prediction request for `buffer` at `position`.
///
/// Everything that needs the app context (snapshots, paths, recorded events,
/// diagnostics) is collected up front. Context gathering, request
/// construction and the HTTP round trip then run on a background task, and the
/// response is converted into an `EditPrediction` back on the foreground.
///
/// The task resolves to `Ok(None)` when no context could be gathered for the
/// cursor position; otherwise to whatever `EditPrediction::from_response`
/// yields.
///
/// # Errors
///
/// - Immediately, with "No file path for excerpt", if the buffer has no file.
/// - If the request itself fails. When the failure is a
///   [`ZedUpdateRequiredError`], `update_required` is set and a notification
///   linking to the releases page is shown before the error is returned.
/// - In debug builds, if the `ZED_ZETA2_SKIP_REQUEST` environment variable is set.
fn request_prediction(
    &mut self,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    position: language::Anchor,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
    // The project may not be registered; in that case there is no index state
    // and no event history.
    let project_state = self.projects.get(&project.entity_id());

    let index_state = project_state.map(|state| {
        state
            .syntax_index
            .read_with(cx, |index, _cx| index.state().clone())
    });
    let options = self.options.clone();
    let snapshot = buffer.read(cx).snapshot();
    let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let worktree_snapshots = project
        .read(cx)
        .worktrees(cx)
        .map(|worktree| worktree.read(cx).snapshot())
        .collect::<Vec<_>>();
    let debug_tx = self.debug_tx.clone();

    // Convert the recorded buffer changes into request events (path, optional
    // old path, unified diff).
    let events = project_state
        .map(|state| {
            state
                .events
                .iter()
                .filter_map(|event| match event {
                    Event::BufferChange {
                        old_snapshot,
                        new_snapshot,
                        ..
                    } => {
                        let path = new_snapshot.file().map(|f| f.full_path(cx));

                        // Only report the old path when it differs from the new one.
                        let old_path = old_snapshot.file().and_then(|f| {
                            let old_path = f.full_path(cx);
                            if Some(&old_path) != path.as_ref() {
                                Some(old_path)
                            } else {
                                None
                            }
                        });

                        // TODO [zeta2] move to bg?
                        let diff =
                            language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                        // `old_path` is `None` when the path is unchanged, so this
                        // comparison only holds when `path` is `None` as well.
                        if path == old_path && diff.is_empty() {
                            None
                        } else {
                            Some(predict_edits_v3::Event::BufferChange {
                                old_path,
                                path,
                                diff,
                                //todo: Actually detect if this edit was predicted or not
                                predicted: false,
                            })
                        }
                    }
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();

    let diagnostics = snapshot.diagnostic_sets().clone();

    // Absolute path of the directory containing the buffer's file, if it has one.
    let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
        let mut path = f.worktree.read(cx).absolutize(&f.path);
        if path.pop() { Some(path) } else { None }
    });

    let request_task = cx.background_spawn({
        let snapshot = snapshot.clone();
        let buffer = buffer.clone();
        async move {
            // Lock the syntax index state for the remainder of the task.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_offset = position.to_offset(&snapshot);
            let cursor_point = cursor_offset.to_point(&snapshot);

            let before_retrieval = chrono::Utc::now();

            let Some(context) = EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                &options.context,
                index_state.as_deref(),
            ) else {
                return Ok(None);
            };

            let retrieval_time = chrono::Utc::now() - before_retrieval;

            let (diagnostic_groups, diagnostic_groups_truncated) =
                Self::gather_nearby_diagnostics(
                    cursor_offset,
                    &diagnostics,
                    &snapshot,
                    options.max_diagnostic_bytes,
                );

            // Keep a copy of the context for the debug channel, since
            // `make_cloud_request` consumes it.
            let debug_context = debug_tx.map(|tx| (tx, context.clone()));

            let request = make_cloud_request(
                excerpt_path,
                context,
                events,
                // TODO data collection
                false,
                diagnostic_groups,
                diagnostic_groups_truncated,
                None,
                debug_context.is_some(),
                &worktree_snapshots,
                index_state.as_deref(),
                Some(options.max_prompt_bytes),
                options.prompt_format,
            );

            // When debugging, publish the request-side info now and keep a
            // sender for delivering the response-side info later.
            let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
                let (response_tx, response_rx) = oneshot::channel();

                let local_prompt = PlannedPrompt::populate(&request)
                    .and_then(|p| p.to_prompt_string().map(|p| p.0))
                    .map_err(|err| err.to_string());

                debug_tx
                    .unbounded_send(PredictionDebugInfo {
                        context,
                        retrieval_time,
                        buffer: buffer.downgrade(),
                        local_prompt,
                        position,
                        response_rx,
                    })
                    .ok();
                Some(response_tx)
            } else {
                None
            };

            // Debug builds only: allows exercising everything up to this point
            // without sending the request.
            if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(Err("Request skipped".to_string()))
                        .ok();
                }
                anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
            }

            let response = Self::perform_request(client, llm_token, app_version, request).await;

            // Forward the response's debug info (or the failure) to the debug receiver.
            if let Some(debug_response_tx) = debug_response_tx {
                debug_response_tx
                    .send(response.as_ref().map_err(|err| err.to_string()).and_then(
                        |response| match some_or_debug_panic(response.0.debug_info.clone()) {
                            Some(debug_info) => Ok(debug_info),
                            None => Err("Missing debug info".to_string()),
                        },
                    ))
                    .ok();
            }

            anyhow::Ok(Some(response?))
        }
    });

    let buffer = buffer.clone();

    cx.spawn({
        let project = project.clone();
        async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    // Record usage reported by the server, if any.
                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    let prediction = EditPrediction::from_response(
                        response, &snapshot, &buffer, &project, cx,
                    )
                    .await;

                    // TODO telemetry: duration, etc
                    Ok(prediction)
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    // The server requires a newer client: remember that and
                    // tell the user how to update.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        }
    })
}
684
685 async fn perform_request(
686 client: Arc<Client>,
687 llm_token: LlmApiToken,
688 app_version: SemanticVersion,
689 request: predict_edits_v3::PredictEditsRequest,
690 ) -> Result<(
691 predict_edits_v3::PredictEditsResponse,
692 Option<EditPredictionUsage>,
693 )> {
694 let http_client = client.http_client();
695 let mut token = llm_token.acquire(&client).await?;
696 let mut did_retry = false;
697
698 loop {
699 let request_builder = http_client::Request::builder().method(Method::POST);
700 let request_builder =
701 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
702 request_builder.uri(predict_edits_url)
703 } else {
704 request_builder.uri(
705 http_client
706 .build_zed_llm_url("/predict_edits/v3", &[])?
707 .as_ref(),
708 )
709 };
710 let request = request_builder
711 .header("Content-Type", "application/json")
712 .header("Authorization", format!("Bearer {}", token))
713 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
714 .body(serde_json::to_string(&request)?.into())?;
715
716 let mut response = http_client.send(request).await?;
717
718 if let Some(minimum_required_version) = response
719 .headers()
720 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
721 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
722 {
723 anyhow::ensure!(
724 app_version >= minimum_required_version,
725 ZedUpdateRequiredError {
726 minimum_version: minimum_required_version
727 }
728 );
729 }
730
731 if response.status().is_success() {
732 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
733
734 let mut body = Vec::new();
735 response.body_mut().read_to_end(&mut body).await?;
736 return Ok((serde_json::from_slice(&body)?, usage));
737 } else if !did_retry
738 && response
739 .headers()
740 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
741 .is_some()
742 {
743 did_retry = true;
744 token = llm_token.refresh(&client).await?;
745 } else {
746 let mut body = String::new();
747 response.body_mut().read_to_string(&mut body).await?;
748 anyhow::bail!(
749 "error predicting edits.\nStatus: {:?}\nBody: {}",
750 response.status(),
751 body
752 );
753 }
754 }
755 }
756
757 fn gather_nearby_diagnostics(
758 cursor_offset: usize,
759 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
760 snapshot: &BufferSnapshot,
761 max_diagnostics_bytes: usize,
762 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
763 // TODO: Could make this more efficient
764 let mut diagnostic_groups = Vec::new();
765 for (language_server_id, diagnostics) in diagnostic_sets {
766 let mut groups = Vec::new();
767 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
768 diagnostic_groups.extend(
769 groups
770 .into_iter()
771 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
772 );
773 }
774
775 // sort by proximity to cursor
776 diagnostic_groups.sort_by_key(|group| {
777 let range = &group.entries[group.primary_ix].range;
778 if range.start >= cursor_offset {
779 range.start - cursor_offset
780 } else if cursor_offset >= range.end {
781 cursor_offset - range.end
782 } else {
783 (cursor_offset - range.start).min(range.end - cursor_offset)
784 }
785 });
786
787 let mut results = Vec::new();
788 let mut diagnostic_groups_truncated = false;
789 let mut diagnostics_byte_count = 0;
790 for group in diagnostic_groups {
791 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
792 diagnostics_byte_count += raw_value.get().len();
793 if diagnostics_byte_count > max_diagnostics_bytes {
794 diagnostic_groups_truncated = true;
795 break;
796 }
797 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
798 }
799
800 (results, diagnostic_groups_truncated)
801 }
802
803 // TODO: Dedupe with similar code in request_prediction?
804 pub fn cloud_request_for_zeta_cli(
805 &mut self,
806 project: &Entity<Project>,
807 buffer: &Entity<Buffer>,
808 position: language::Anchor,
809 cx: &mut Context<Self>,
810 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
811 let project_state = self.projects.get(&project.entity_id());
812
813 let index_state = project_state.map(|state| {
814 state
815 .syntax_index
816 .read_with(cx, |index, _cx| index.state().clone())
817 });
818 let options = self.options.clone();
819 let snapshot = buffer.read(cx).snapshot();
820 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
821 return Task::ready(Err(anyhow!("No file path for excerpt")));
822 };
823 let worktree_snapshots = project
824 .read(cx)
825 .worktrees(cx)
826 .map(|worktree| worktree.read(cx).snapshot())
827 .collect::<Vec<_>>();
828
829 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
830 let mut path = f.worktree.read(cx).absolutize(&f.path);
831 if path.pop() { Some(path) } else { None }
832 });
833
834 cx.background_spawn(async move {
835 let index_state = if let Some(index_state) = index_state {
836 Some(index_state.lock_owned().await)
837 } else {
838 None
839 };
840
841 let cursor_point = position.to_point(&snapshot);
842
843 let debug_info = true;
844 EditPredictionContext::gather_context(
845 cursor_point,
846 &snapshot,
847 parent_abs_path.as_deref(),
848 &options.context,
849 index_state.as_deref(),
850 )
851 .context("Failed to select excerpt")
852 .map(|context| {
853 make_cloud_request(
854 excerpt_path.into(),
855 context,
856 // TODO pass everything
857 Vec::new(),
858 false,
859 Vec::new(),
860 false,
861 None,
862 debug_info,
863 &worktree_snapshots,
864 index_state.as_deref(),
865 Some(options.max_prompt_bytes),
866 options.prompt_format,
867 )
868 })
869 })
870 }
871
/// Returns the task produced by the project's syntax index for waiting on
/// initial file indexing, registering the project first if necessary.
pub fn wait_for_initial_indexing(
    &mut self,
    project: &Entity<Project>,
    cx: &mut App,
) -> Task<Result<()>> {
    let zeta_project = self.get_or_init_zeta_project(project, cx);
    zeta_project
        .syntax_index
        .read(cx)
        .wait_for_initial_file_indexing(cx)
}
883}
884
/// Error returned when the server's minimum required version header names a
/// version newer than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: SemanticVersion,
}
892
893fn make_cloud_request(
894 excerpt_path: Arc<Path>,
895 context: EditPredictionContext,
896 events: Vec<predict_edits_v3::Event>,
897 can_collect_data: bool,
898 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
899 diagnostic_groups_truncated: bool,
900 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
901 debug_info: bool,
902 worktrees: &Vec<worktree::Snapshot>,
903 index_state: Option<&SyntaxIndexState>,
904 prompt_max_bytes: Option<usize>,
905 prompt_format: PromptFormat,
906) -> predict_edits_v3::PredictEditsRequest {
907 let mut signatures = Vec::new();
908 let mut declaration_to_signature_index = HashMap::default();
909 let mut referenced_declarations = Vec::new();
910
911 for snippet in context.declarations {
912 let project_entry_id = snippet.declaration.project_entry_id();
913 let Some(path) = worktrees.iter().find_map(|worktree| {
914 worktree.entry_for_id(project_entry_id).map(|entry| {
915 let mut full_path = RelPathBuf::new();
916 full_path.push(worktree.root_name());
917 full_path.push(&entry.path);
918 full_path
919 })
920 }) else {
921 continue;
922 };
923
924 let parent_index = index_state.and_then(|index_state| {
925 snippet.declaration.parent().and_then(|parent| {
926 add_signature(
927 parent,
928 &mut declaration_to_signature_index,
929 &mut signatures,
930 index_state,
931 )
932 })
933 });
934
935 let (text, text_is_truncated) = snippet.declaration.item_text();
936 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
937 path: path.as_std_path().into(),
938 text: text.into(),
939 range: snippet.declaration.item_line_range(),
940 text_is_truncated,
941 signature_range: snippet.declaration.signature_range_in_item_text(),
942 parent_index,
943 signature_score: snippet.score(DeclarationStyle::Signature),
944 declaration_score: snippet.score(DeclarationStyle::Declaration),
945 score_components: snippet.components,
946 });
947 }
948
949 let excerpt_parent = index_state.and_then(|index_state| {
950 context
951 .excerpt
952 .parent_declarations
953 .last()
954 .and_then(|(parent, _)| {
955 add_signature(
956 *parent,
957 &mut declaration_to_signature_index,
958 &mut signatures,
959 index_state,
960 )
961 })
962 });
963
964 predict_edits_v3::PredictEditsRequest {
965 excerpt_path,
966 excerpt: context.excerpt_text.body,
967 excerpt_line_range: context.excerpt.line_range,
968 excerpt_range: context.excerpt.range,
969 cursor_point: predict_edits_v3::Point {
970 line: predict_edits_v3::Line(context.cursor_point.row),
971 column: context.cursor_point.column,
972 },
973 referenced_declarations,
974 signatures,
975 excerpt_parent,
976 events,
977 can_collect_data,
978 diagnostic_groups,
979 diagnostic_groups_truncated,
980 git_info,
981 debug_info,
982 prompt_max_bytes,
983 prompt_format,
984 }
985}
986
987fn add_signature(
988 declaration_id: DeclarationId,
989 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
990 signatures: &mut Vec<Signature>,
991 index: &SyntaxIndexState,
992) -> Option<usize> {
993 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
994 return Some(*signature_index);
995 }
996 let Some(parent_declaration) = index.declaration(declaration_id) else {
997 log::error!("bug: missing parent declaration");
998 return None;
999 };
1000 let parent_index = parent_declaration.parent().and_then(|parent| {
1001 add_signature(parent, declaration_to_signature_index, signatures, index)
1002 });
1003 let (text, text_is_truncated) = parent_declaration.signature_text();
1004 let signature_index = signatures.len();
1005 signatures.push(Signature {
1006 text: text.into(),
1007 text_is_truncated,
1008 parent_index,
1009 range: parent_declaration.signature_line_range(),
1010 });
1011 declaration_to_signature_index.insert(declaration_id, signature_index);
1012 Some(signature_index)
1013}
1014
1015#[cfg(test)]
1016mod tests {
1017 use std::{
1018 path::{Path, PathBuf},
1019 sync::Arc,
1020 };
1021
1022 use client::UserStore;
1023 use clock::FakeSystemClock;
1024 use cloud_llm_client::predict_edits_v3::{self, Point};
1025 use edit_prediction_context::Line;
1026 use futures::{
1027 AsyncReadExt, StreamExt,
1028 channel::{mpsc, oneshot},
1029 };
1030 use gpui::{
1031 Entity, TestAppContext,
1032 http_client::{FakeHttpClient, Response},
1033 prelude::*,
1034 };
1035 use indoc::indoc;
1036 use language::{LanguageServerId, OffsetRangeExt as _};
1037 use pretty_assertions::{assert_eq, assert_matches};
1038 use project::{FakeFs, Project};
1039 use serde_json::json;
1040 use settings::SettingsStore;
1041 use util::path;
1042 use uuid::Uuid;
1043
1044 use crate::{BufferEditPrediction, Zeta};
1045
/// Checks how the current prediction is classified per buffer: a prediction
/// editing the requesting file is `Local` for it, while a prediction editing
/// another file is a `Jump` for the requesting buffer and `Local` for the
/// target buffer once that buffer is opened.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    let (zeta, mut req_rx) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye",
            "2.txt": "Hola!\nComo\nAdios"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    zeta.update(cx, |zeta, cx| {
        zeta.register_project(&project, cx);
    });

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    // Prediction for current file

    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    // Answer the intercepted request with an edit to the requesting file.
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    respond_tx
        .send(predict_edits_v3::PredictEditsResponse {
            request_id: Uuid::new_v4(),
            edits: vec![predict_edits_v3::Edit {
                path: Path::new(path!("root/1.txt")).into(),
                range: Line(0)..Line(snapshot1.max_point().row + 1),
                content: "Hello!\nHow are you?\nBye".into(),
            }],
            debug_info: None,
        })
        .unwrap();
    prediction_task.await.unwrap();

    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Prediction for another file
    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    // Answer with an edit to 2.txt although the request came from 1.txt.
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    respond_tx
        .send(predict_edits_v3::PredictEditsResponse {
            request_id: Uuid::new_v4(),
            edits: vec![predict_edits_v3::Edit {
                path: Path::new(path!("root/2.txt")).into(),
                range: Line(0)..Line(snapshot1.max_point().row + 1),
                content: "Hola!\nComo estas?\nAdios".into(),
            }],
            debug_info: None,
        })
        .unwrap();
    prediction_task.await.unwrap();
    // From the requesting buffer's point of view this is a jump to 2.txt.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // From the target buffer's point of view the same prediction is local.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer2, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
1142
1143 #[gpui::test]
1144 async fn test_simple_request(cx: &mut TestAppContext) {
1145 let (zeta, mut req_rx) = init_test(cx);
1146 let fs = FakeFs::new(cx.executor());
1147 fs.insert_tree(
1148 "/root",
1149 json!({
1150 "foo.md": "Hello!\nHow\nBye"
1151 }),
1152 )
1153 .await;
1154 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1155
1156 let buffer = project
1157 .update(cx, |project, cx| {
1158 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1159 project.open_buffer(path, cx)
1160 })
1161 .await
1162 .unwrap();
1163 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1164 let position = snapshot.anchor_before(language::Point::new(1, 3));
1165
1166 let prediction_task = zeta.update(cx, |zeta, cx| {
1167 zeta.request_prediction(&project, &buffer, position, cx)
1168 });
1169
1170 let (request, respond_tx) = req_rx.next().await.unwrap();
1171 assert_eq!(
1172 request.excerpt_path.as_ref(),
1173 Path::new(path!("root/foo.md"))
1174 );
1175 assert_eq!(
1176 request.cursor_point,
1177 Point {
1178 line: Line(1),
1179 column: 3
1180 }
1181 );
1182
1183 respond_tx
1184 .send(predict_edits_v3::PredictEditsResponse {
1185 request_id: Uuid::new_v4(),
1186 edits: vec![predict_edits_v3::Edit {
1187 path: Path::new(path!("root/foo.md")).into(),
1188 range: Line(0)..Line(snapshot.max_point().row + 1),
1189 content: "Hello!\nHow are you?\nBye".into(),
1190 }],
1191 debug_info: None,
1192 })
1193 .unwrap();
1194
1195 let prediction = prediction_task.await.unwrap().unwrap();
1196
1197 assert_eq!(prediction.edits.len(), 1);
1198 assert_eq!(
1199 prediction.edits[0].0.to_point(&snapshot).start,
1200 language::Point::new(1, 3)
1201 );
1202 assert_eq!(prediction.edits[0].1, " are you?");
1203 }
1204
1205 #[gpui::test]
1206 async fn test_request_events(cx: &mut TestAppContext) {
1207 let (zeta, mut req_rx) = init_test(cx);
1208 let fs = FakeFs::new(cx.executor());
1209 fs.insert_tree(
1210 "/root",
1211 json!({
1212 "foo.md": "Hello!\n\nBye"
1213 }),
1214 )
1215 .await;
1216 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1217
1218 let buffer = project
1219 .update(cx, |project, cx| {
1220 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1221 project.open_buffer(path, cx)
1222 })
1223 .await
1224 .unwrap();
1225
1226 zeta.update(cx, |zeta, cx| {
1227 zeta.register_buffer(&buffer, &project, cx);
1228 });
1229
1230 buffer.update(cx, |buffer, cx| {
1231 buffer.edit(vec![(7..7, "How")], None, cx);
1232 });
1233
1234 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1235 let position = snapshot.anchor_before(language::Point::new(1, 3));
1236
1237 let prediction_task = zeta.update(cx, |zeta, cx| {
1238 zeta.request_prediction(&project, &buffer, position, cx)
1239 });
1240
1241 let (request, respond_tx) = req_rx.next().await.unwrap();
1242
1243 assert_eq!(request.events.len(), 1);
1244 assert_eq!(
1245 request.events[0],
1246 predict_edits_v3::Event::BufferChange {
1247 path: Some(PathBuf::from(path!("root/foo.md"))),
1248 old_path: None,
1249 diff: indoc! {"
1250 @@ -1,3 +1,3 @@
1251 Hello!
1252 -
1253 +How
1254 Bye
1255 "}
1256 .to_string(),
1257 predicted: false
1258 }
1259 );
1260
1261 respond_tx
1262 .send(predict_edits_v3::PredictEditsResponse {
1263 request_id: Uuid::new_v4(),
1264 edits: vec![predict_edits_v3::Edit {
1265 path: Path::new(path!("root/foo.md")).into(),
1266 range: Line(0)..Line(snapshot.max_point().row + 1),
1267 content: "Hello!\nHow are you?\nBye".into(),
1268 }],
1269 debug_info: None,
1270 })
1271 .unwrap();
1272
1273 let prediction = prediction_task.await.unwrap().unwrap();
1274
1275 assert_eq!(prediction.edits.len(), 1);
1276 assert_eq!(
1277 prediction.edits[0].0.to_point(&snapshot).start,
1278 language::Point::new(1, 3)
1279 );
1280 assert_eq!(prediction.edits[0].1, " are you?");
1281 }
1282
1283 #[gpui::test]
1284 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1285 let (zeta, mut req_rx) = init_test(cx);
1286 let fs = FakeFs::new(cx.executor());
1287 fs.insert_tree(
1288 "/root",
1289 json!({
1290 "foo.md": "Hello!\nBye"
1291 }),
1292 )
1293 .await;
1294 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1295
1296 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1297 let diagnostic = lsp::Diagnostic {
1298 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1299 severity: Some(lsp::DiagnosticSeverity::ERROR),
1300 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1301 ..Default::default()
1302 };
1303
1304 project.update(cx, |project, cx| {
1305 project.lsp_store().update(cx, |lsp_store, cx| {
1306 // Create some diagnostics
1307 lsp_store
1308 .update_diagnostics(
1309 LanguageServerId(0),
1310 lsp::PublishDiagnosticsParams {
1311 uri: path_to_buffer_uri.clone(),
1312 diagnostics: vec![diagnostic],
1313 version: None,
1314 },
1315 None,
1316 language::DiagnosticSourceKind::Pushed,
1317 &[],
1318 cx,
1319 )
1320 .unwrap();
1321 });
1322 });
1323
1324 let buffer = project
1325 .update(cx, |project, cx| {
1326 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1327 project.open_buffer(path, cx)
1328 })
1329 .await
1330 .unwrap();
1331
1332 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1333 let position = snapshot.anchor_before(language::Point::new(0, 0));
1334
1335 let _prediction_task = zeta.update(cx, |zeta, cx| {
1336 zeta.request_prediction(&project, &buffer, position, cx)
1337 });
1338
1339 let (request, _respond_tx) = req_rx.next().await.unwrap();
1340
1341 assert_eq!(request.diagnostic_groups.len(), 1);
1342 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1343 .unwrap();
1344 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1345 assert_eq!(
1346 value,
1347 json!({
1348 "entries": [{
1349 "range": {
1350 "start": 8,
1351 "end": 10
1352 },
1353 "diagnostic": {
1354 "source": null,
1355 "code": null,
1356 "code_description": null,
1357 "severity": 1,
1358 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1359 "markdown": null,
1360 "group_id": 0,
1361 "is_primary": true,
1362 "is_disk_based": false,
1363 "is_unnecessary": false,
1364 "source_kind": "Pushed",
1365 "data": null,
1366 "underline": true
1367 }
1368 }],
1369 "primary_ix": 0
1370 })
1371 );
1372 }
1373
1374 fn init_test(
1375 cx: &mut TestAppContext,
1376 ) -> (
1377 Entity<Zeta>,
1378 mpsc::UnboundedReceiver<(
1379 predict_edits_v3::PredictEditsRequest,
1380 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1381 )>,
1382 ) {
1383 cx.update(move |cx| {
1384 let settings_store = SettingsStore::test(cx);
1385 cx.set_global(settings_store);
1386 language::init(cx);
1387 Project::init_settings(cx);
1388
1389 let (req_tx, req_rx) = mpsc::unbounded();
1390
1391 let http_client = FakeHttpClient::create({
1392 move |req| {
1393 let uri = req.uri().path().to_string();
1394 let mut body = req.into_body();
1395 let req_tx = req_tx.clone();
1396 async move {
1397 let resp = match uri.as_str() {
1398 "/client/llm_tokens" => serde_json::to_string(&json!({
1399 "token": "test"
1400 }))
1401 .unwrap(),
1402 "/predict_edits/v3" => {
1403 let mut buf = Vec::new();
1404 body.read_to_end(&mut buf).await.ok();
1405 let req = serde_json::from_slice(&buf).unwrap();
1406
1407 let (res_tx, res_rx) = oneshot::channel();
1408 req_tx.unbounded_send((req, res_tx)).unwrap();
1409 serde_json::to_string(&res_rx.await?).unwrap()
1410 }
1411 _ => {
1412 panic!("Unexpected path: {}", uri)
1413 }
1414 };
1415
1416 Ok(Response::builder().body(resp.into()).unwrap())
1417 }
1418 }
1419 });
1420
1421 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1422 client.cloud_client().set_credentials(1, "test".into());
1423
1424 language_model::init(client.clone(), cx);
1425
1426 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1427 let zeta = Zeta::global(&client, &user_store, cx);
1428
1429 (zeta, req_rx)
1430 })
1431 }
1432}