1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
11 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::{mpsc, oneshot};
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
21use language::{BufferSnapshot, TextBufferSnapshot};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Consecutive edits to the same buffer that are reported within this window of each other
/// are merged into a single history event instead of being recorded separately.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of edit events kept per project. When the history is full, the
/// oldest half is discarded before a new event is appended.
const MAX_EVENT_COUNT: usize = 16;
45
46pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
47 use_imports: true,
48 excerpt: EditPredictionExcerptOptions {
49 max_bytes: 512,
50 min_bytes: 128,
51 target_before_cursor_over_total_bytes: 0.5,
52 },
53 score: EditPredictionScoreOptions {
54 omit_excerpt_overlaps: true,
55 },
56};
57
58pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
59 context: DEFAULT_CONTEXT_OPTIONS,
60 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
61 max_diagnostic_bytes: 2048,
62 prompt_format: PromptFormat::DEFAULT,
63 file_indexing_parallelism: 1,
64};
65
66#[derive(Clone)]
67struct ZetaGlobal(Entity<Zeta>);
68
69impl Global for ZetaGlobal {}
70
71pub struct Zeta {
72 client: Arc<Client>,
73 user_store: Entity<UserStore>,
74 llm_token: LlmApiToken,
75 _llm_token_subscription: Subscription,
76 projects: HashMap<EntityId, ZetaProject>,
77 options: ZetaOptions,
78 update_required: bool,
79 debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
80}
81
82#[derive(Debug, Clone, PartialEq)]
83pub struct ZetaOptions {
84 pub context: EditPredictionContextOptions,
85 pub max_prompt_bytes: usize,
86 pub max_diagnostic_bytes: usize,
87 pub prompt_format: predict_edits_v3::PromptFormat,
88 pub file_indexing_parallelism: usize,
89}
90
91pub struct PredictionDebugInfo {
92 pub context: EditPredictionContext,
93 pub retrieval_time: TimeDelta,
94 pub buffer: WeakEntity<Buffer>,
95 pub position: language::Anchor,
96 pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
97}
98
99pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
100
101struct ZetaProject {
102 syntax_index: Entity<SyntaxIndex>,
103 events: VecDeque<Event>,
104 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
105 current_prediction: Option<CurrentEditPrediction>,
106}
107
108#[derive(Clone)]
109struct CurrentEditPrediction {
110 pub requested_by_buffer_id: EntityId,
111 pub prediction: EditPrediction,
112}
113
114impl CurrentEditPrediction {
115 fn should_replace_prediction(
116 &self,
117 old_prediction: &Self,
118 snapshot: &TextBufferSnapshot,
119 ) -> bool {
120 if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
121 return true;
122 }
123
124 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
125 return true;
126 };
127
128 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
129 return false;
130 };
131 if old_edits.len() == 1 && new_edits.len() == 1 {
132 let (old_range, old_text) = &old_edits[0];
133 let (new_range, new_text) = &new_edits[0];
134 new_range == old_range && new_text.starts_with(old_text)
135 } else {
136 true
137 }
138 }
139}
140
141/// A prediction from the perspective of a buffer.
142#[derive(Debug)]
143enum BufferEditPrediction<'a> {
144 Local { prediction: &'a EditPrediction },
145 Jump { prediction: &'a EditPrediction },
146}
147
148struct RegisteredBuffer {
149 snapshot: BufferSnapshot,
150 _subscriptions: [gpui::Subscription; 2],
151}
152
153#[derive(Clone)]
154pub enum Event {
155 BufferChange {
156 old_snapshot: BufferSnapshot,
157 new_snapshot: BufferSnapshot,
158 timestamp: Instant,
159 },
160}
161
162impl Zeta {
163 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
164 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
165 }
166
167 pub fn global(
168 client: &Arc<Client>,
169 user_store: &Entity<UserStore>,
170 cx: &mut App,
171 ) -> Entity<Self> {
172 cx.try_global::<ZetaGlobal>()
173 .map(|global| global.0.clone())
174 .unwrap_or_else(|| {
175 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
176 cx.set_global(ZetaGlobal(zeta.clone()));
177 zeta
178 })
179 }
180
181 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
182 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
183
184 Self {
185 projects: HashMap::new(),
186 client,
187 user_store,
188 options: DEFAULT_OPTIONS,
189 llm_token: LlmApiToken::default(),
190 _llm_token_subscription: cx.subscribe(
191 &refresh_llm_token_listener,
192 |this, _listener, _event, cx| {
193 let client = this.client.clone();
194 let llm_token = this.llm_token.clone();
195 cx.spawn(async move |_this, _cx| {
196 llm_token.refresh(&client).await?;
197 anyhow::Ok(())
198 })
199 .detach_and_log_err(cx);
200 },
201 ),
202 update_required: false,
203 debug_tx: None,
204 }
205 }
206
207 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
208 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
209 self.debug_tx = Some(debug_watch_tx);
210 debug_watch_rx
211 }
212
213 pub fn options(&self) -> &ZetaOptions {
214 &self.options
215 }
216
217 pub fn set_options(&mut self, options: ZetaOptions) {
218 self.options = options;
219 }
220
221 pub fn clear_history(&mut self) {
222 for zeta_project in self.projects.values_mut() {
223 zeta_project.events.clear();
224 }
225 }
226
227 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
228 self.user_store.read(cx).edit_prediction_usage()
229 }
230
231 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
232 self.get_or_init_zeta_project(project, cx);
233 }
234
235 pub fn register_buffer(
236 &mut self,
237 buffer: &Entity<Buffer>,
238 project: &Entity<Project>,
239 cx: &mut Context<Self>,
240 ) {
241 let zeta_project = self.get_or_init_zeta_project(project, cx);
242 Self::register_buffer_impl(zeta_project, buffer, project, cx);
243 }
244
245 fn get_or_init_zeta_project(
246 &mut self,
247 project: &Entity<Project>,
248 cx: &mut App,
249 ) -> &mut ZetaProject {
250 self.projects
251 .entry(project.entity_id())
252 .or_insert_with(|| ZetaProject {
253 syntax_index: cx.new(|cx| {
254 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
255 }),
256 events: VecDeque::new(),
257 registered_buffers: HashMap::new(),
258 current_prediction: None,
259 })
260 }
261
262 fn register_buffer_impl<'a>(
263 zeta_project: &'a mut ZetaProject,
264 buffer: &Entity<Buffer>,
265 project: &Entity<Project>,
266 cx: &mut Context<Self>,
267 ) -> &'a mut RegisteredBuffer {
268 let buffer_id = buffer.entity_id();
269 match zeta_project.registered_buffers.entry(buffer_id) {
270 hash_map::Entry::Occupied(entry) => entry.into_mut(),
271 hash_map::Entry::Vacant(entry) => {
272 let snapshot = buffer.read(cx).snapshot();
273 let project_entity_id = project.entity_id();
274 entry.insert(RegisteredBuffer {
275 snapshot,
276 _subscriptions: [
277 cx.subscribe(buffer, {
278 let project = project.downgrade();
279 move |this, buffer, event, cx| {
280 if let language::BufferEvent::Edited = event
281 && let Some(project) = project.upgrade()
282 {
283 this.report_changes_for_buffer(&buffer, &project, cx);
284 }
285 }
286 }),
287 cx.observe_release(buffer, move |this, _buffer, _cx| {
288 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
289 else {
290 return;
291 };
292 zeta_project.registered_buffers.remove(&buffer_id);
293 }),
294 ],
295 })
296 }
297 }
298 }
299
300 fn report_changes_for_buffer(
301 &mut self,
302 buffer: &Entity<Buffer>,
303 project: &Entity<Project>,
304 cx: &mut Context<Self>,
305 ) -> BufferSnapshot {
306 let zeta_project = self.get_or_init_zeta_project(project, cx);
307 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
308
309 let new_snapshot = buffer.read(cx).snapshot();
310 if new_snapshot.version != registered_buffer.snapshot.version {
311 let old_snapshot =
312 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
313 Self::push_event(
314 zeta_project,
315 Event::BufferChange {
316 old_snapshot,
317 new_snapshot: new_snapshot.clone(),
318 timestamp: Instant::now(),
319 },
320 );
321 }
322
323 new_snapshot
324 }
325
326 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
327 let events = &mut zeta_project.events;
328
329 if let Some(Event::BufferChange {
330 new_snapshot: last_new_snapshot,
331 timestamp: last_timestamp,
332 ..
333 }) = events.back_mut()
334 {
335 // Coalesce edits for the same buffer when they happen one after the other.
336 let Event::BufferChange {
337 old_snapshot,
338 new_snapshot,
339 timestamp,
340 } = &event;
341
342 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
343 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
344 && old_snapshot.version == last_new_snapshot.version
345 {
346 *last_new_snapshot = new_snapshot.clone();
347 *last_timestamp = *timestamp;
348 return;
349 }
350 }
351
352 if events.len() >= MAX_EVENT_COUNT {
353 // These are halved instead of popping to improve prompt caching.
354 events.drain(..MAX_EVENT_COUNT / 2);
355 }
356
357 events.push_back(event);
358 }
359
360 fn current_prediction_for_buffer(
361 &self,
362 buffer: &Entity<Buffer>,
363 project: &Entity<Project>,
364 cx: &App,
365 ) -> Option<BufferEditPrediction<'_>> {
366 let project_state = self.projects.get(&project.entity_id())?;
367
368 let CurrentEditPrediction {
369 requested_by_buffer_id,
370 prediction,
371 } = project_state.current_prediction.as_ref()?;
372
373 if prediction.targets_buffer(buffer.read(cx), cx) {
374 Some(BufferEditPrediction::Local { prediction })
375 } else if *requested_by_buffer_id == buffer.entity_id() {
376 Some(BufferEditPrediction::Jump { prediction })
377 } else {
378 None
379 }
380 }
381
382 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
383 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
384 project_state.current_prediction.take();
385 };
386 // TODO report accepted
387 }
388
389 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
390 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
391 project_state.current_prediction.take();
392 };
393 }
394
395 pub fn refresh_prediction(
396 &mut self,
397 project: &Entity<Project>,
398 buffer: &Entity<Buffer>,
399 position: language::Anchor,
400 cx: &mut Context<Self>,
401 ) -> Task<Result<()>> {
402 let request_task = self.request_prediction(project, buffer, position, cx);
403 let buffer = buffer.clone();
404 let project = project.clone();
405
406 cx.spawn(async move |this, cx| {
407 if let Some(prediction) = request_task.await? {
408 this.update(cx, |this, cx| {
409 let project_state = this
410 .projects
411 .get_mut(&project.entity_id())
412 .context("Project not found")?;
413
414 let new_prediction = CurrentEditPrediction {
415 requested_by_buffer_id: buffer.entity_id(),
416 prediction: prediction,
417 };
418
419 if project_state
420 .current_prediction
421 .as_ref()
422 .is_none_or(|old_prediction| {
423 new_prediction
424 .should_replace_prediction(&old_prediction, buffer.read(cx))
425 })
426 {
427 project_state.current_prediction = Some(new_prediction);
428 }
429 anyhow::Ok(())
430 })??;
431 }
432 Ok(())
433 })
434 }
435
436 fn request_prediction(
437 &mut self,
438 project: &Entity<Project>,
439 buffer: &Entity<Buffer>,
440 position: language::Anchor,
441 cx: &mut Context<Self>,
442 ) -> Task<Result<Option<EditPrediction>>> {
443 let project_state = self.projects.get(&project.entity_id());
444
445 let index_state = project_state.map(|state| {
446 state
447 .syntax_index
448 .read_with(cx, |index, _cx| index.state().clone())
449 });
450 let options = self.options.clone();
451 let snapshot = buffer.read(cx).snapshot();
452 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
453 return Task::ready(Err(anyhow!("No file path for excerpt")));
454 };
455 let client = self.client.clone();
456 let llm_token = self.llm_token.clone();
457 let app_version = AppVersion::global(cx);
458 let worktree_snapshots = project
459 .read(cx)
460 .worktrees(cx)
461 .map(|worktree| worktree.read(cx).snapshot())
462 .collect::<Vec<_>>();
463 let debug_tx = self.debug_tx.clone();
464
465 let events = project_state
466 .map(|state| {
467 state
468 .events
469 .iter()
470 .filter_map(|event| match event {
471 Event::BufferChange {
472 old_snapshot,
473 new_snapshot,
474 ..
475 } => {
476 let path = new_snapshot.file().map(|f| f.full_path(cx));
477
478 let old_path = old_snapshot.file().and_then(|f| {
479 let old_path = f.full_path(cx);
480 if Some(&old_path) != path.as_ref() {
481 Some(old_path)
482 } else {
483 None
484 }
485 });
486
487 // TODO [zeta2] move to bg?
488 let diff =
489 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
490
491 if path == old_path && diff.is_empty() {
492 None
493 } else {
494 Some(predict_edits_v3::Event::BufferChange {
495 old_path,
496 path,
497 diff,
498 //todo: Actually detect if this edit was predicted or not
499 predicted: false,
500 })
501 }
502 }
503 })
504 .collect::<Vec<_>>()
505 })
506 .unwrap_or_default();
507
508 let diagnostics = snapshot.diagnostic_sets().clone();
509
510 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
511 let mut path = f.worktree.read(cx).absolutize(&f.path);
512 if path.pop() { Some(path) } else { None }
513 });
514
515 let request_task = cx.background_spawn({
516 let snapshot = snapshot.clone();
517 let buffer = buffer.clone();
518 async move {
519 let index_state = if let Some(index_state) = index_state {
520 Some(index_state.lock_owned().await)
521 } else {
522 None
523 };
524
525 let cursor_offset = position.to_offset(&snapshot);
526 let cursor_point = cursor_offset.to_point(&snapshot);
527
528 let before_retrieval = chrono::Utc::now();
529
530 let Some(context) = EditPredictionContext::gather_context(
531 cursor_point,
532 &snapshot,
533 parent_abs_path.as_deref(),
534 &options.context,
535 index_state.as_deref(),
536 ) else {
537 return Ok(None);
538 };
539
540 let retrieval_time = chrono::Utc::now() - before_retrieval;
541
542 let debug_response_tx = if let Some(debug_tx) = debug_tx {
543 let (response_tx, response_rx) = oneshot::channel();
544 let context = context.clone();
545
546 debug_tx
547 .unbounded_send(PredictionDebugInfo {
548 context,
549 retrieval_time,
550 buffer: buffer.downgrade(),
551 position,
552 response_rx,
553 })
554 .ok();
555 Some(response_tx)
556 } else {
557 None
558 };
559
560 let (diagnostic_groups, diagnostic_groups_truncated) =
561 Self::gather_nearby_diagnostics(
562 cursor_offset,
563 &diagnostics,
564 &snapshot,
565 options.max_diagnostic_bytes,
566 );
567
568 let request = make_cloud_request(
569 excerpt_path,
570 context,
571 events,
572 // TODO data collection
573 false,
574 diagnostic_groups,
575 diagnostic_groups_truncated,
576 None,
577 debug_response_tx.is_some(),
578 &worktree_snapshots,
579 index_state.as_deref(),
580 Some(options.max_prompt_bytes),
581 options.prompt_format,
582 );
583
584 let response = Self::perform_request(client, llm_token, app_version, request).await;
585
586 if let Some(debug_response_tx) = debug_response_tx {
587 debug_response_tx
588 .send(response.as_ref().map_err(|err| err.to_string()).and_then(
589 |response| match some_or_debug_panic(response.0.debug_info.clone()) {
590 Some(debug_info) => Ok(debug_info),
591 None => Err("Missing debug info".to_string()),
592 },
593 ))
594 .ok();
595 }
596
597 anyhow::Ok(Some(response?))
598 }
599 });
600
601 let buffer = buffer.clone();
602
603 cx.spawn({
604 let project = project.clone();
605 async move |this, cx| {
606 match request_task.await {
607 Ok(Some((response, usage))) => {
608 if let Some(usage) = usage {
609 this.update(cx, |this, cx| {
610 this.user_store.update(cx, |user_store, cx| {
611 user_store.update_edit_prediction_usage(usage, cx);
612 });
613 })
614 .ok();
615 }
616
617 let prediction = EditPrediction::from_response(
618 response, &snapshot, &buffer, &project, cx,
619 )
620 .await;
621
622 // TODO telemetry: duration, etc
623 Ok(prediction)
624 }
625 Ok(None) => Ok(None),
626 Err(err) => {
627 if err.is::<ZedUpdateRequiredError>() {
628 cx.update(|cx| {
629 this.update(cx, |this, _cx| {
630 this.update_required = true;
631 })
632 .ok();
633
634 let error_message: SharedString = err.to_string().into();
635 show_app_notification(
636 NotificationId::unique::<ZedUpdateRequiredError>(),
637 cx,
638 move |cx| {
639 cx.new(|cx| {
640 ErrorMessagePrompt::new(error_message.clone(), cx)
641 .with_link_button(
642 "Update Zed",
643 "https://zed.dev/releases",
644 )
645 })
646 },
647 );
648 })
649 .ok();
650 }
651
652 Err(err)
653 }
654 }
655 }
656 })
657 }
658
659 async fn perform_request(
660 client: Arc<Client>,
661 llm_token: LlmApiToken,
662 app_version: SemanticVersion,
663 request: predict_edits_v3::PredictEditsRequest,
664 ) -> Result<(
665 predict_edits_v3::PredictEditsResponse,
666 Option<EditPredictionUsage>,
667 )> {
668 let http_client = client.http_client();
669 let mut token = llm_token.acquire(&client).await?;
670 let mut did_retry = false;
671
672 loop {
673 let request_builder = http_client::Request::builder().method(Method::POST);
674 let request_builder =
675 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
676 request_builder.uri(predict_edits_url)
677 } else {
678 request_builder.uri(
679 http_client
680 .build_zed_llm_url("/predict_edits/v3", &[])?
681 .as_ref(),
682 )
683 };
684 let request = request_builder
685 .header("Content-Type", "application/json")
686 .header("Authorization", format!("Bearer {}", token))
687 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
688 .body(serde_json::to_string(&request)?.into())?;
689
690 let mut response = http_client.send(request).await?;
691
692 if let Some(minimum_required_version) = response
693 .headers()
694 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
695 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
696 {
697 anyhow::ensure!(
698 app_version >= minimum_required_version,
699 ZedUpdateRequiredError {
700 minimum_version: minimum_required_version
701 }
702 );
703 }
704
705 if response.status().is_success() {
706 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
707
708 let mut body = Vec::new();
709 response.body_mut().read_to_end(&mut body).await?;
710 return Ok((serde_json::from_slice(&body)?, usage));
711 } else if !did_retry
712 && response
713 .headers()
714 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
715 .is_some()
716 {
717 did_retry = true;
718 token = llm_token.refresh(&client).await?;
719 } else {
720 let mut body = String::new();
721 response.body_mut().read_to_string(&mut body).await?;
722 anyhow::bail!(
723 "error predicting edits.\nStatus: {:?}\nBody: {}",
724 response.status(),
725 body
726 );
727 }
728 }
729 }
730
731 fn gather_nearby_diagnostics(
732 cursor_offset: usize,
733 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
734 snapshot: &BufferSnapshot,
735 max_diagnostics_bytes: usize,
736 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
737 // TODO: Could make this more efficient
738 let mut diagnostic_groups = Vec::new();
739 for (language_server_id, diagnostics) in diagnostic_sets {
740 let mut groups = Vec::new();
741 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
742 diagnostic_groups.extend(
743 groups
744 .into_iter()
745 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
746 );
747 }
748
749 // sort by proximity to cursor
750 diagnostic_groups.sort_by_key(|group| {
751 let range = &group.entries[group.primary_ix].range;
752 if range.start >= cursor_offset {
753 range.start - cursor_offset
754 } else if cursor_offset >= range.end {
755 cursor_offset - range.end
756 } else {
757 (cursor_offset - range.start).min(range.end - cursor_offset)
758 }
759 });
760
761 let mut results = Vec::new();
762 let mut diagnostic_groups_truncated = false;
763 let mut diagnostics_byte_count = 0;
764 for group in diagnostic_groups {
765 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
766 diagnostics_byte_count += raw_value.get().len();
767 if diagnostics_byte_count > max_diagnostics_bytes {
768 diagnostic_groups_truncated = true;
769 break;
770 }
771 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
772 }
773
774 (results, diagnostic_groups_truncated)
775 }
776
777 // TODO: Dedupe with similar code in request_prediction?
778 pub fn cloud_request_for_zeta_cli(
779 &mut self,
780 project: &Entity<Project>,
781 buffer: &Entity<Buffer>,
782 position: language::Anchor,
783 cx: &mut Context<Self>,
784 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
785 let project_state = self.projects.get(&project.entity_id());
786
787 let index_state = project_state.map(|state| {
788 state
789 .syntax_index
790 .read_with(cx, |index, _cx| index.state().clone())
791 });
792 let options = self.options.clone();
793 let snapshot = buffer.read(cx).snapshot();
794 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
795 return Task::ready(Err(anyhow!("No file path for excerpt")));
796 };
797 let worktree_snapshots = project
798 .read(cx)
799 .worktrees(cx)
800 .map(|worktree| worktree.read(cx).snapshot())
801 .collect::<Vec<_>>();
802
803 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
804 let mut path = f.worktree.read(cx).absolutize(&f.path);
805 if path.pop() { Some(path) } else { None }
806 });
807
808 cx.background_spawn(async move {
809 let index_state = if let Some(index_state) = index_state {
810 Some(index_state.lock_owned().await)
811 } else {
812 None
813 };
814
815 let cursor_point = position.to_point(&snapshot);
816
817 let debug_info = true;
818 EditPredictionContext::gather_context(
819 cursor_point,
820 &snapshot,
821 parent_abs_path.as_deref(),
822 &options.context,
823 index_state.as_deref(),
824 )
825 .context("Failed to select excerpt")
826 .map(|context| {
827 make_cloud_request(
828 excerpt_path.into(),
829 context,
830 // TODO pass everything
831 Vec::new(),
832 false,
833 Vec::new(),
834 false,
835 None,
836 debug_info,
837 &worktree_snapshots,
838 index_state.as_deref(),
839 Some(options.max_prompt_bytes),
840 options.prompt_format,
841 )
842 })
843 })
844 }
845
846 pub fn wait_for_initial_indexing(
847 &mut self,
848 project: &Entity<Project>,
849 cx: &mut App,
850 ) -> Task<Result<()>> {
851 let zeta_project = self.get_or_init_zeta_project(project, cx);
852 zeta_project
853 .syntax_index
854 .read(cx)
855 .wait_for_initial_file_indexing(cx)
856 }
857}
858
859#[derive(Error, Debug)]
860#[error(
861 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
862)]
863pub struct ZedUpdateRequiredError {
864 minimum_version: SemanticVersion,
865}
866
867fn make_cloud_request(
868 excerpt_path: Arc<Path>,
869 context: EditPredictionContext,
870 events: Vec<predict_edits_v3::Event>,
871 can_collect_data: bool,
872 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
873 diagnostic_groups_truncated: bool,
874 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
875 debug_info: bool,
876 worktrees: &Vec<worktree::Snapshot>,
877 index_state: Option<&SyntaxIndexState>,
878 prompt_max_bytes: Option<usize>,
879 prompt_format: PromptFormat,
880) -> predict_edits_v3::PredictEditsRequest {
881 let mut signatures = Vec::new();
882 let mut declaration_to_signature_index = HashMap::default();
883 let mut referenced_declarations = Vec::new();
884
885 for snippet in context.declarations {
886 let project_entry_id = snippet.declaration.project_entry_id();
887 let Some(path) = worktrees.iter().find_map(|worktree| {
888 worktree.entry_for_id(project_entry_id).map(|entry| {
889 let mut full_path = RelPathBuf::new();
890 full_path.push(worktree.root_name());
891 full_path.push(&entry.path);
892 full_path
893 })
894 }) else {
895 continue;
896 };
897
898 let parent_index = index_state.and_then(|index_state| {
899 snippet.declaration.parent().and_then(|parent| {
900 add_signature(
901 parent,
902 &mut declaration_to_signature_index,
903 &mut signatures,
904 index_state,
905 )
906 })
907 });
908
909 let (text, text_is_truncated) = snippet.declaration.item_text();
910 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
911 path: path.as_std_path().into(),
912 text: text.into(),
913 range: snippet.declaration.item_range(),
914 text_is_truncated,
915 signature_range: snippet.declaration.signature_range_in_item_text(),
916 parent_index,
917 signature_score: snippet.score(DeclarationStyle::Signature),
918 declaration_score: snippet.score(DeclarationStyle::Declaration),
919 score_components: snippet.components,
920 });
921 }
922
923 let excerpt_parent = index_state.and_then(|index_state| {
924 context
925 .excerpt
926 .parent_declarations
927 .last()
928 .and_then(|(parent, _)| {
929 add_signature(
930 *parent,
931 &mut declaration_to_signature_index,
932 &mut signatures,
933 index_state,
934 )
935 })
936 });
937
938 predict_edits_v3::PredictEditsRequest {
939 excerpt_path,
940 excerpt: context.excerpt_text.body,
941 excerpt_range: context.excerpt.range,
942 cursor_offset: context.cursor_offset_in_excerpt,
943 referenced_declarations,
944 signatures,
945 excerpt_parent,
946 events,
947 can_collect_data,
948 diagnostic_groups,
949 diagnostic_groups_truncated,
950 git_info,
951 debug_info,
952 prompt_max_bytes,
953 prompt_format,
954 }
955}
956
957fn add_signature(
958 declaration_id: DeclarationId,
959 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
960 signatures: &mut Vec<Signature>,
961 index: &SyntaxIndexState,
962) -> Option<usize> {
963 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
964 return Some(*signature_index);
965 }
966 let Some(parent_declaration) = index.declaration(declaration_id) else {
967 log::error!("bug: missing parent declaration");
968 return None;
969 };
970 let parent_index = parent_declaration.parent().and_then(|parent| {
971 add_signature(parent, declaration_to_signature_index, signatures, index)
972 });
973 let (text, text_is_truncated) = parent_declaration.signature_text();
974 let signature_index = signatures.len();
975 signatures.push(Signature {
976 text: text.into(),
977 text_is_truncated,
978 parent_index,
979 range: parent_declaration.signature_range(),
980 });
981 declaration_to_signature_index.insert(declaration_id, signature_index);
982 Some(signature_index)
983}
984
985#[cfg(test)]
986mod tests {
987 use std::{
988 path::{Path, PathBuf},
989 sync::Arc,
990 };
991
992 use client::UserStore;
993 use clock::FakeSystemClock;
994 use cloud_llm_client::predict_edits_v3;
995 use futures::{
996 AsyncReadExt, StreamExt,
997 channel::{mpsc, oneshot},
998 };
999 use gpui::{
1000 Entity, TestAppContext,
1001 http_client::{FakeHttpClient, Response},
1002 prelude::*,
1003 };
1004 use indoc::indoc;
1005 use language::{LanguageServerId, OffsetRangeExt as _};
1006 use pretty_assertions::{assert_eq, assert_matches};
1007 use project::{FakeFs, Project};
1008 use serde_json::json;
1009 use settings::SettingsStore;
1010 use util::path;
1011 use uuid::Uuid;
1012
1013 use crate::{BufferEditPrediction, Zeta};
1014
1015 #[gpui::test]
1016 async fn test_current_state(cx: &mut TestAppContext) {
1017 let (zeta, mut req_rx) = init_test(cx);
1018 let fs = FakeFs::new(cx.executor());
1019 fs.insert_tree(
1020 "/root",
1021 json!({
1022 "1.txt": "Hello!\nHow\nBye",
1023 "2.txt": "Hola!\nComo\nAdios"
1024 }),
1025 )
1026 .await;
1027 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1028
1029 zeta.update(cx, |zeta, cx| {
1030 zeta.register_project(&project, cx);
1031 });
1032
1033 let buffer1 = project
1034 .update(cx, |project, cx| {
1035 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1036 project.open_buffer(path, cx)
1037 })
1038 .await
1039 .unwrap();
1040 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1041 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1042
1043 // Prediction for current file
1044
1045 let prediction_task = zeta.update(cx, |zeta, cx| {
1046 zeta.refresh_prediction(&project, &buffer1, position, cx)
1047 });
1048 let (_request, respond_tx) = req_rx.next().await.unwrap();
1049 respond_tx
1050 .send(predict_edits_v3::PredictEditsResponse {
1051 request_id: Uuid::new_v4(),
1052 edits: vec![predict_edits_v3::Edit {
1053 path: Path::new(path!("root/1.txt")).into(),
1054 range: 0..snapshot1.len(),
1055 content: "Hello!\nHow are you?\nBye".into(),
1056 }],
1057 debug_info: None,
1058 })
1059 .unwrap();
1060 prediction_task.await.unwrap();
1061
1062 zeta.read_with(cx, |zeta, cx| {
1063 let prediction = zeta
1064 .current_prediction_for_buffer(&buffer1, &project, cx)
1065 .unwrap();
1066 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1067 });
1068
1069 // Prediction for another file
1070
1071 let prediction_task = zeta.update(cx, |zeta, cx| {
1072 zeta.refresh_prediction(&project, &buffer1, position, cx)
1073 });
1074 let (_request, respond_tx) = req_rx.next().await.unwrap();
1075 respond_tx
1076 .send(predict_edits_v3::PredictEditsResponse {
1077 request_id: Uuid::new_v4(),
1078 edits: vec![predict_edits_v3::Edit {
1079 path: Path::new(path!("root/2.txt")).into(),
1080 range: 0..snapshot1.len(),
1081 content: "Hola!\nComo estas?\nAdios".into(),
1082 }],
1083 debug_info: None,
1084 })
1085 .unwrap();
1086 prediction_task.await.unwrap();
1087
1088 zeta.read_with(cx, |zeta, cx| {
1089 let prediction = zeta
1090 .current_prediction_for_buffer(&buffer1, &project, cx)
1091 .unwrap();
1092 assert_matches!(
1093 prediction,
1094 BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1095 );
1096 });
1097
1098 let buffer2 = project
1099 .update(cx, |project, cx| {
1100 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1101 project.open_buffer(path, cx)
1102 })
1103 .await
1104 .unwrap();
1105
1106 zeta.read_with(cx, |zeta, cx| {
1107 let prediction = zeta
1108 .current_prediction_for_buffer(&buffer2, &project, cx)
1109 .unwrap();
1110 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1111 });
1112 }
1113
1114 #[gpui::test]
1115 async fn test_simple_request(cx: &mut TestAppContext) {
1116 let (zeta, mut req_rx) = init_test(cx);
1117 let fs = FakeFs::new(cx.executor());
1118 fs.insert_tree(
1119 "/root",
1120 json!({
1121 "foo.md": "Hello!\nHow\nBye"
1122 }),
1123 )
1124 .await;
1125 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1126
1127 let buffer = project
1128 .update(cx, |project, cx| {
1129 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1130 project.open_buffer(path, cx)
1131 })
1132 .await
1133 .unwrap();
1134 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1135 let position = snapshot.anchor_before(language::Point::new(1, 3));
1136
1137 let prediction_task = zeta.update(cx, |zeta, cx| {
1138 zeta.request_prediction(&project, &buffer, position, cx)
1139 });
1140
1141 let (request, respond_tx) = req_rx.next().await.unwrap();
1142 assert_eq!(
1143 request.excerpt_path.as_ref(),
1144 Path::new(path!("root/foo.md"))
1145 );
1146 assert_eq!(request.cursor_offset, 10);
1147
1148 respond_tx
1149 .send(predict_edits_v3::PredictEditsResponse {
1150 request_id: Uuid::new_v4(),
1151 edits: vec![predict_edits_v3::Edit {
1152 path: Path::new(path!("root/foo.md")).into(),
1153 range: 0..snapshot.len(),
1154 content: "Hello!\nHow are you?\nBye".into(),
1155 }],
1156 debug_info: None,
1157 })
1158 .unwrap();
1159
1160 let prediction = prediction_task.await.unwrap().unwrap();
1161
1162 assert_eq!(prediction.edits.len(), 1);
1163 assert_eq!(
1164 prediction.edits[0].0.to_point(&snapshot).start,
1165 language::Point::new(1, 3)
1166 );
1167 assert_eq!(prediction.edits[0].1, " are you?");
1168 }
1169
1170 #[gpui::test]
1171 async fn test_request_events(cx: &mut TestAppContext) {
1172 let (zeta, mut req_rx) = init_test(cx);
1173 let fs = FakeFs::new(cx.executor());
1174 fs.insert_tree(
1175 "/root",
1176 json!({
1177 "foo.md": "Hello!\n\nBye"
1178 }),
1179 )
1180 .await;
1181 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1182
1183 let buffer = project
1184 .update(cx, |project, cx| {
1185 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1186 project.open_buffer(path, cx)
1187 })
1188 .await
1189 .unwrap();
1190
1191 zeta.update(cx, |zeta, cx| {
1192 zeta.register_buffer(&buffer, &project, cx);
1193 });
1194
1195 buffer.update(cx, |buffer, cx| {
1196 buffer.edit(vec![(7..7, "How")], None, cx);
1197 });
1198
1199 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1200 let position = snapshot.anchor_before(language::Point::new(1, 3));
1201
1202 let prediction_task = zeta.update(cx, |zeta, cx| {
1203 zeta.request_prediction(&project, &buffer, position, cx)
1204 });
1205
1206 let (request, respond_tx) = req_rx.next().await.unwrap();
1207
1208 assert_eq!(request.events.len(), 1);
1209 assert_eq!(
1210 request.events[0],
1211 predict_edits_v3::Event::BufferChange {
1212 path: Some(PathBuf::from(path!("root/foo.md"))),
1213 old_path: None,
1214 diff: indoc! {"
1215 @@ -1,3 +1,3 @@
1216 Hello!
1217 -
1218 +How
1219 Bye
1220 "}
1221 .to_string(),
1222 predicted: false
1223 }
1224 );
1225
1226 respond_tx
1227 .send(predict_edits_v3::PredictEditsResponse {
1228 request_id: Uuid::new_v4(),
1229 edits: vec![predict_edits_v3::Edit {
1230 path: Path::new(path!("root/foo.md")).into(),
1231 range: 0..snapshot.len(),
1232 content: "Hello!\nHow are you?\nBye".into(),
1233 }],
1234 debug_info: None,
1235 })
1236 .unwrap();
1237
1238 let prediction = prediction_task.await.unwrap().unwrap();
1239
1240 assert_eq!(prediction.edits.len(), 1);
1241 assert_eq!(
1242 prediction.edits[0].0.to_point(&snapshot).start,
1243 language::Point::new(1, 3)
1244 );
1245 assert_eq!(prediction.edits[0].1, " are you?");
1246 }
1247
1248 #[gpui::test]
1249 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1250 let (zeta, mut req_rx) = init_test(cx);
1251 let fs = FakeFs::new(cx.executor());
1252 fs.insert_tree(
1253 "/root",
1254 json!({
1255 "foo.md": "Hello!\nBye"
1256 }),
1257 )
1258 .await;
1259 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1260
1261 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1262 let diagnostic = lsp::Diagnostic {
1263 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1264 severity: Some(lsp::DiagnosticSeverity::ERROR),
1265 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1266 ..Default::default()
1267 };
1268
1269 project.update(cx, |project, cx| {
1270 project.lsp_store().update(cx, |lsp_store, cx| {
1271 // Create some diagnostics
1272 lsp_store
1273 .update_diagnostics(
1274 LanguageServerId(0),
1275 lsp::PublishDiagnosticsParams {
1276 uri: path_to_buffer_uri.clone(),
1277 diagnostics: vec![diagnostic],
1278 version: None,
1279 },
1280 None,
1281 language::DiagnosticSourceKind::Pushed,
1282 &[],
1283 cx,
1284 )
1285 .unwrap();
1286 });
1287 });
1288
1289 let buffer = project
1290 .update(cx, |project, cx| {
1291 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1292 project.open_buffer(path, cx)
1293 })
1294 .await
1295 .unwrap();
1296
1297 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1298 let position = snapshot.anchor_before(language::Point::new(0, 0));
1299
1300 let _prediction_task = zeta.update(cx, |zeta, cx| {
1301 zeta.request_prediction(&project, &buffer, position, cx)
1302 });
1303
1304 let (request, _respond_tx) = req_rx.next().await.unwrap();
1305
1306 assert_eq!(request.diagnostic_groups.len(), 1);
1307 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1308 .unwrap();
1309 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1310 assert_eq!(
1311 value,
1312 json!({
1313 "entries": [{
1314 "range": {
1315 "start": 8,
1316 "end": 10
1317 },
1318 "diagnostic": {
1319 "source": null,
1320 "code": null,
1321 "code_description": null,
1322 "severity": 1,
1323 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1324 "markdown": null,
1325 "group_id": 0,
1326 "is_primary": true,
1327 "is_disk_based": false,
1328 "is_unnecessary": false,
1329 "source_kind": "Pushed",
1330 "data": null,
1331 "underline": true
1332 }
1333 }],
1334 "primary_ix": 0
1335 })
1336 );
1337 }
1338
1339 fn init_test(
1340 cx: &mut TestAppContext,
1341 ) -> (
1342 Entity<Zeta>,
1343 mpsc::UnboundedReceiver<(
1344 predict_edits_v3::PredictEditsRequest,
1345 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1346 )>,
1347 ) {
1348 cx.update(move |cx| {
1349 let settings_store = SettingsStore::test(cx);
1350 cx.set_global(settings_store);
1351 language::init(cx);
1352 Project::init_settings(cx);
1353
1354 let (req_tx, req_rx) = mpsc::unbounded();
1355
1356 let http_client = FakeHttpClient::create({
1357 move |req| {
1358 let uri = req.uri().path().to_string();
1359 let mut body = req.into_body();
1360 let req_tx = req_tx.clone();
1361 async move {
1362 let resp = match uri.as_str() {
1363 "/client/llm_tokens" => serde_json::to_string(&json!({
1364 "token": "test"
1365 }))
1366 .unwrap(),
1367 "/predict_edits/v3" => {
1368 let mut buf = Vec::new();
1369 body.read_to_end(&mut buf).await.ok();
1370 let req = serde_json::from_slice(&buf).unwrap();
1371
1372 let (res_tx, res_rx) = oneshot::channel();
1373 req_tx.unbounded_send((req, res_tx)).unwrap();
1374 serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1375 }
1376 _ => {
1377 panic!("Unexpected path: {}", uri)
1378 }
1379 };
1380
1381 Ok(Response::builder().body(resp.into()).unwrap())
1382 }
1383 }
1384 });
1385
1386 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1387 client.cloud_client().set_credentials(1, "test".into());
1388
1389 language_model::init(client.clone(), cx);
1390
1391 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1392 let zeta = Zeta::global(&client, &user_store, cx);
1393
1394 (zeta, req_rx)
1395 })
1396 }
1397}