1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
9use edit_prediction_context::{
10 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
11 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::{mpsc, oneshot};
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
21use language::{BufferSnapshot, TextBufferSnapshot};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Edits to the same buffer that arrive within this interval of the previously
/// recorded change are merged into that change (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When the limit is reached, the oldest half of the history is dropped at once
/// (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;

/// Default options for gathering the excerpt and referenced declarations around
/// the cursor.
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
};

/// Options a newly created `Zeta` starts with (see `Zeta::new`); they can be
/// replaced with `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
65
/// App-global handle to the shared `Zeta` entity; installed by `Zeta::global`
/// and read by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
70
/// Edit-prediction engine: tracks per-project edit history and the current
/// prediction, and sends prediction requests to the `/predict_edits/v3` endpoint.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Keeps the LLM-token refresh subscription alive for as long as `Zeta` lives.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is too old.
    update_required: bool,
    /// When set (via `debug_info`), requests report a `PredictionDebugInfo` here.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}

/// Tunables for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options for gathering the excerpt and referenced declarations.
    pub context: EditPredictionContextOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}

/// Debug data emitted for a prediction request while a debug channel is
/// installed (see `Zeta::debug_info`).
pub struct PredictionDebugInfo {
    /// The context gathered for the request.
    pub context: EditPredictionContext,
    /// Time spent gathering `context`.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The prompt rendered locally from the request, or the rendering error.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server-provided debug info, or an error message.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}

/// Debug info returned by the server alongside a prediction.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
}

/// The prediction currently offered for a project, together with the buffer
/// from which it was requested.
#[derive(Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
114
115impl CurrentEditPrediction {
116 fn should_replace_prediction(
117 &self,
118 old_prediction: &Self,
119 snapshot: &TextBufferSnapshot,
120 ) -> bool {
121 if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
122 return true;
123 }
124
125 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
126 return true;
127 };
128
129 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
130 return false;
131 };
132 if old_edits.len() == 1 && new_edits.len() == 1 {
133 let (old_range, old_text) = &old_edits[0];
134 let (new_range, new_text) = &new_edits[0];
135 new_range == old_range && new_text.starts_with(old_text)
136 } else {
137 true
138 }
139 }
140}
141
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits another buffer, but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}

/// A buffer whose edits are tracked, with the last snapshot that was recorded
/// for it.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    /// Held to keep the edit subscription and the release observer alive.
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. `timestamp` is when
    /// the change was last recorded (it advances when later edits are merged in).
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
162
163impl Zeta {
164 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
165 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
166 }
167
168 pub fn global(
169 client: &Arc<Client>,
170 user_store: &Entity<UserStore>,
171 cx: &mut App,
172 ) -> Entity<Self> {
173 cx.try_global::<ZetaGlobal>()
174 .map(|global| global.0.clone())
175 .unwrap_or_else(|| {
176 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
177 cx.set_global(ZetaGlobal(zeta.clone()));
178 zeta
179 })
180 }
181
182 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
183 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
184
185 Self {
186 projects: HashMap::new(),
187 client,
188 user_store,
189 options: DEFAULT_OPTIONS,
190 llm_token: LlmApiToken::default(),
191 _llm_token_subscription: cx.subscribe(
192 &refresh_llm_token_listener,
193 |this, _listener, _event, cx| {
194 let client = this.client.clone();
195 let llm_token = this.llm_token.clone();
196 cx.spawn(async move |_this, _cx| {
197 llm_token.refresh(&client).await?;
198 anyhow::Ok(())
199 })
200 .detach_and_log_err(cx);
201 },
202 ),
203 update_required: false,
204 debug_tx: None,
205 }
206 }
207
208 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
209 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
210 self.debug_tx = Some(debug_watch_tx);
211 debug_watch_rx
212 }
213
214 pub fn options(&self) -> &ZetaOptions {
215 &self.options
216 }
217
218 pub fn set_options(&mut self, options: ZetaOptions) {
219 self.options = options;
220 }
221
222 pub fn clear_history(&mut self) {
223 for zeta_project in self.projects.values_mut() {
224 zeta_project.events.clear();
225 }
226 }
227
228 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
229 self.user_store.read(cx).edit_prediction_usage()
230 }
231
232 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
233 self.get_or_init_zeta_project(project, cx);
234 }
235
236 pub fn register_buffer(
237 &mut self,
238 buffer: &Entity<Buffer>,
239 project: &Entity<Project>,
240 cx: &mut Context<Self>,
241 ) {
242 let zeta_project = self.get_or_init_zeta_project(project, cx);
243 Self::register_buffer_impl(zeta_project, buffer, project, cx);
244 }
245
246 fn get_or_init_zeta_project(
247 &mut self,
248 project: &Entity<Project>,
249 cx: &mut App,
250 ) -> &mut ZetaProject {
251 self.projects
252 .entry(project.entity_id())
253 .or_insert_with(|| ZetaProject {
254 syntax_index: cx.new(|cx| {
255 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
256 }),
257 events: VecDeque::new(),
258 registered_buffers: HashMap::new(),
259 current_prediction: None,
260 })
261 }
262
    /// Returns the `RegisteredBuffer` entry for `buffer` in `zeta_project`,
    /// creating it on first use.
    ///
    /// A new entry stores the buffer's current snapshot and two subscriptions:
    /// - on every `BufferEvent::Edited`, it calls `report_changes_for_buffer`,
    ///   but only while the project can still be upgraded;
    /// - when the buffer entity is released, it removes the entry from the
    ///   project's `registered_buffers`.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                // Captured by id so the release observer can find the project state
                // without holding the project itself.
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Capture a weak handle so this subscription does not retain the project.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
300
301 fn report_changes_for_buffer(
302 &mut self,
303 buffer: &Entity<Buffer>,
304 project: &Entity<Project>,
305 cx: &mut Context<Self>,
306 ) -> BufferSnapshot {
307 let zeta_project = self.get_or_init_zeta_project(project, cx);
308 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
309
310 let new_snapshot = buffer.read(cx).snapshot();
311 if new_snapshot.version != registered_buffer.snapshot.version {
312 let old_snapshot =
313 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
314 Self::push_event(
315 zeta_project,
316 Event::BufferChange {
317 old_snapshot,
318 new_snapshot: new_snapshot.clone(),
319 timestamp: Instant::now(),
320 },
321 );
322 }
323
324 new_snapshot
325 }
326
    /// Appends `event` to the project's edit history.
    ///
    /// If the previous event changed the same buffer (same `remote_id`), ended at
    /// exactly the version this one starts from, and was recorded no more than
    /// `BUFFER_CHANGE_GROUPING_INTERVAL` earlier, the two are merged instead of
    /// appending. When the history holds `MAX_EVENT_COUNT` events, the oldest half
    /// is discarded before appending.
    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
        let events = &mut zeta_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            timestamp: last_timestamp,
            ..
        }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                // Extend the previous event: it keeps its original old snapshot but
                // takes on this event's end state and time.
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }
360
361 fn current_prediction_for_buffer(
362 &self,
363 buffer: &Entity<Buffer>,
364 project: &Entity<Project>,
365 cx: &App,
366 ) -> Option<BufferEditPrediction<'_>> {
367 let project_state = self.projects.get(&project.entity_id())?;
368
369 let CurrentEditPrediction {
370 requested_by_buffer_id,
371 prediction,
372 } = project_state.current_prediction.as_ref()?;
373
374 if prediction.targets_buffer(buffer.read(cx), cx) {
375 Some(BufferEditPrediction::Local { prediction })
376 } else if *requested_by_buffer_id == buffer.entity_id() {
377 Some(BufferEditPrediction::Jump { prediction })
378 } else {
379 None
380 }
381 }
382
383 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
384 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
385 project_state.current_prediction.take();
386 };
387 // TODO report accepted
388 }
389
390 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
391 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
392 project_state.current_prediction.take();
393 };
394 }
395
396 pub fn refresh_prediction(
397 &mut self,
398 project: &Entity<Project>,
399 buffer: &Entity<Buffer>,
400 position: language::Anchor,
401 cx: &mut Context<Self>,
402 ) -> Task<Result<()>> {
403 let request_task = self.request_prediction(project, buffer, position, cx);
404 let buffer = buffer.clone();
405 let project = project.clone();
406
407 cx.spawn(async move |this, cx| {
408 if let Some(prediction) = request_task.await? {
409 this.update(cx, |this, cx| {
410 let project_state = this
411 .projects
412 .get_mut(&project.entity_id())
413 .context("Project not found")?;
414
415 let new_prediction = CurrentEditPrediction {
416 requested_by_buffer_id: buffer.entity_id(),
417 prediction: prediction,
418 };
419
420 if project_state
421 .current_prediction
422 .as_ref()
423 .is_none_or(|old_prediction| {
424 new_prediction
425 .should_replace_prediction(&old_prediction, buffer.read(cx))
426 })
427 {
428 project_state.current_prediction = Some(new_prediction);
429 }
430 anyhow::Ok(())
431 })??;
432 }
433 Ok(())
434 })
435 }
436
437 fn request_prediction(
438 &mut self,
439 project: &Entity<Project>,
440 buffer: &Entity<Buffer>,
441 position: language::Anchor,
442 cx: &mut Context<Self>,
443 ) -> Task<Result<Option<EditPrediction>>> {
444 let project_state = self.projects.get(&project.entity_id());
445
446 let index_state = project_state.map(|state| {
447 state
448 .syntax_index
449 .read_with(cx, |index, _cx| index.state().clone())
450 });
451 let options = self.options.clone();
452 let snapshot = buffer.read(cx).snapshot();
453 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
454 return Task::ready(Err(anyhow!("No file path for excerpt")));
455 };
456 let client = self.client.clone();
457 let llm_token = self.llm_token.clone();
458 let app_version = AppVersion::global(cx);
459 let worktree_snapshots = project
460 .read(cx)
461 .worktrees(cx)
462 .map(|worktree| worktree.read(cx).snapshot())
463 .collect::<Vec<_>>();
464 let debug_tx = self.debug_tx.clone();
465
466 let events = project_state
467 .map(|state| {
468 state
469 .events
470 .iter()
471 .filter_map(|event| match event {
472 Event::BufferChange {
473 old_snapshot,
474 new_snapshot,
475 ..
476 } => {
477 let path = new_snapshot.file().map(|f| f.full_path(cx));
478
479 let old_path = old_snapshot.file().and_then(|f| {
480 let old_path = f.full_path(cx);
481 if Some(&old_path) != path.as_ref() {
482 Some(old_path)
483 } else {
484 None
485 }
486 });
487
488 // TODO [zeta2] move to bg?
489 let diff =
490 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
491
492 if path == old_path && diff.is_empty() {
493 None
494 } else {
495 Some(predict_edits_v3::Event::BufferChange {
496 old_path,
497 path,
498 diff,
499 //todo: Actually detect if this edit was predicted or not
500 predicted: false,
501 })
502 }
503 }
504 })
505 .collect::<Vec<_>>()
506 })
507 .unwrap_or_default();
508
509 let diagnostics = snapshot.diagnostic_sets().clone();
510
511 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
512 let mut path = f.worktree.read(cx).absolutize(&f.path);
513 if path.pop() { Some(path) } else { None }
514 });
515
516 let request_task = cx.background_spawn({
517 let snapshot = snapshot.clone();
518 let buffer = buffer.clone();
519 async move {
520 let index_state = if let Some(index_state) = index_state {
521 Some(index_state.lock_owned().await)
522 } else {
523 None
524 };
525
526 let cursor_offset = position.to_offset(&snapshot);
527 let cursor_point = cursor_offset.to_point(&snapshot);
528
529 let before_retrieval = chrono::Utc::now();
530
531 let Some(context) = EditPredictionContext::gather_context(
532 cursor_point,
533 &snapshot,
534 parent_abs_path.as_deref(),
535 &options.context,
536 index_state.as_deref(),
537 ) else {
538 return Ok(None);
539 };
540
541 let retrieval_time = chrono::Utc::now() - before_retrieval;
542
543 let (diagnostic_groups, diagnostic_groups_truncated) =
544 Self::gather_nearby_diagnostics(
545 cursor_offset,
546 &diagnostics,
547 &snapshot,
548 options.max_diagnostic_bytes,
549 );
550
551 let debug_context = debug_tx.map(|tx| (tx, context.clone()));
552
553 let request = make_cloud_request(
554 excerpt_path,
555 context,
556 events,
557 // TODO data collection
558 false,
559 diagnostic_groups,
560 diagnostic_groups_truncated,
561 None,
562 debug_context.is_some(),
563 &worktree_snapshots,
564 index_state.as_deref(),
565 Some(options.max_prompt_bytes),
566 options.prompt_format,
567 );
568
569 let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
570 let (response_tx, response_rx) = oneshot::channel();
571
572 let local_prompt = PlannedPrompt::populate(&request)
573 .and_then(|p| p.to_prompt_string().map(|p| p.0))
574 .map_err(|err| err.to_string());
575
576 debug_tx
577 .unbounded_send(PredictionDebugInfo {
578 context,
579 retrieval_time,
580 buffer: buffer.downgrade(),
581 local_prompt,
582 position,
583 response_rx,
584 })
585 .ok();
586 Some(response_tx)
587 } else {
588 None
589 };
590
591 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
592 if let Some(debug_response_tx) = debug_response_tx {
593 debug_response_tx
594 .send(Err("Request skipped".to_string()))
595 .ok();
596 }
597 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
598 }
599
600 let response = Self::perform_request(client, llm_token, app_version, request).await;
601
602 if let Some(debug_response_tx) = debug_response_tx {
603 debug_response_tx
604 .send(response.as_ref().map_err(|err| err.to_string()).and_then(
605 |response| match some_or_debug_panic(response.0.debug_info.clone()) {
606 Some(debug_info) => Ok(debug_info),
607 None => Err("Missing debug info".to_string()),
608 },
609 ))
610 .ok();
611 }
612
613 anyhow::Ok(Some(response?))
614 }
615 });
616
617 let buffer = buffer.clone();
618
619 cx.spawn({
620 let project = project.clone();
621 async move |this, cx| {
622 match request_task.await {
623 Ok(Some((response, usage))) => {
624 if let Some(usage) = usage {
625 this.update(cx, |this, cx| {
626 this.user_store.update(cx, |user_store, cx| {
627 user_store.update_edit_prediction_usage(usage, cx);
628 });
629 })
630 .ok();
631 }
632
633 let prediction = EditPrediction::from_response(
634 response, &snapshot, &buffer, &project, cx,
635 )
636 .await;
637
638 // TODO telemetry: duration, etc
639 Ok(prediction)
640 }
641 Ok(None) => Ok(None),
642 Err(err) => {
643 if err.is::<ZedUpdateRequiredError>() {
644 cx.update(|cx| {
645 this.update(cx, |this, _cx| {
646 this.update_required = true;
647 })
648 .ok();
649
650 let error_message: SharedString = err.to_string().into();
651 show_app_notification(
652 NotificationId::unique::<ZedUpdateRequiredError>(),
653 cx,
654 move |cx| {
655 cx.new(|cx| {
656 ErrorMessagePrompt::new(error_message.clone(), cx)
657 .with_link_button(
658 "Update Zed",
659 "https://zed.dev/releases",
660 )
661 })
662 },
663 );
664 })
665 .ok();
666 }
667
668 Err(err)
669 }
670 }
671 }
672 })
673 }
674
675 async fn perform_request(
676 client: Arc<Client>,
677 llm_token: LlmApiToken,
678 app_version: SemanticVersion,
679 request: predict_edits_v3::PredictEditsRequest,
680 ) -> Result<(
681 predict_edits_v3::PredictEditsResponse,
682 Option<EditPredictionUsage>,
683 )> {
684 let http_client = client.http_client();
685 let mut token = llm_token.acquire(&client).await?;
686 let mut did_retry = false;
687
688 loop {
689 let request_builder = http_client::Request::builder().method(Method::POST);
690 let request_builder =
691 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
692 request_builder.uri(predict_edits_url)
693 } else {
694 request_builder.uri(
695 http_client
696 .build_zed_llm_url("/predict_edits/v3", &[])?
697 .as_ref(),
698 )
699 };
700 let request = request_builder
701 .header("Content-Type", "application/json")
702 .header("Authorization", format!("Bearer {}", token))
703 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
704 .body(serde_json::to_string(&request)?.into())?;
705
706 let mut response = http_client.send(request).await?;
707
708 if let Some(minimum_required_version) = response
709 .headers()
710 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
711 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
712 {
713 anyhow::ensure!(
714 app_version >= minimum_required_version,
715 ZedUpdateRequiredError {
716 minimum_version: minimum_required_version
717 }
718 );
719 }
720
721 if response.status().is_success() {
722 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
723
724 let mut body = Vec::new();
725 response.body_mut().read_to_end(&mut body).await?;
726 return Ok((serde_json::from_slice(&body)?, usage));
727 } else if !did_retry
728 && response
729 .headers()
730 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
731 .is_some()
732 {
733 did_retry = true;
734 token = llm_token.refresh(&client).await?;
735 } else {
736 let mut body = String::new();
737 response.body_mut().read_to_string(&mut body).await?;
738 anyhow::bail!(
739 "error predicting edits.\nStatus: {:?}\nBody: {}",
740 response.status(),
741 body
742 );
743 }
744 }
745 }
746
747 fn gather_nearby_diagnostics(
748 cursor_offset: usize,
749 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
750 snapshot: &BufferSnapshot,
751 max_diagnostics_bytes: usize,
752 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
753 // TODO: Could make this more efficient
754 let mut diagnostic_groups = Vec::new();
755 for (language_server_id, diagnostics) in diagnostic_sets {
756 let mut groups = Vec::new();
757 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
758 diagnostic_groups.extend(
759 groups
760 .into_iter()
761 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
762 );
763 }
764
765 // sort by proximity to cursor
766 diagnostic_groups.sort_by_key(|group| {
767 let range = &group.entries[group.primary_ix].range;
768 if range.start >= cursor_offset {
769 range.start - cursor_offset
770 } else if cursor_offset >= range.end {
771 cursor_offset - range.end
772 } else {
773 (cursor_offset - range.start).min(range.end - cursor_offset)
774 }
775 });
776
777 let mut results = Vec::new();
778 let mut diagnostic_groups_truncated = false;
779 let mut diagnostics_byte_count = 0;
780 for group in diagnostic_groups {
781 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
782 diagnostics_byte_count += raw_value.get().len();
783 if diagnostics_byte_count > max_diagnostics_bytes {
784 diagnostic_groups_truncated = true;
785 break;
786 }
787 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
788 }
789
790 (results, diagnostic_groups_truncated)
791 }
792
793 // TODO: Dedupe with similar code in request_prediction?
794 pub fn cloud_request_for_zeta_cli(
795 &mut self,
796 project: &Entity<Project>,
797 buffer: &Entity<Buffer>,
798 position: language::Anchor,
799 cx: &mut Context<Self>,
800 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
801 let project_state = self.projects.get(&project.entity_id());
802
803 let index_state = project_state.map(|state| {
804 state
805 .syntax_index
806 .read_with(cx, |index, _cx| index.state().clone())
807 });
808 let options = self.options.clone();
809 let snapshot = buffer.read(cx).snapshot();
810 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
811 return Task::ready(Err(anyhow!("No file path for excerpt")));
812 };
813 let worktree_snapshots = project
814 .read(cx)
815 .worktrees(cx)
816 .map(|worktree| worktree.read(cx).snapshot())
817 .collect::<Vec<_>>();
818
819 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
820 let mut path = f.worktree.read(cx).absolutize(&f.path);
821 if path.pop() { Some(path) } else { None }
822 });
823
824 cx.background_spawn(async move {
825 let index_state = if let Some(index_state) = index_state {
826 Some(index_state.lock_owned().await)
827 } else {
828 None
829 };
830
831 let cursor_point = position.to_point(&snapshot);
832
833 let debug_info = true;
834 EditPredictionContext::gather_context(
835 cursor_point,
836 &snapshot,
837 parent_abs_path.as_deref(),
838 &options.context,
839 index_state.as_deref(),
840 )
841 .context("Failed to select excerpt")
842 .map(|context| {
843 make_cloud_request(
844 excerpt_path.into(),
845 context,
846 // TODO pass everything
847 Vec::new(),
848 false,
849 Vec::new(),
850 false,
851 None,
852 debug_info,
853 &worktree_snapshots,
854 index_state.as_deref(),
855 Some(options.max_prompt_bytes),
856 options.prompt_format,
857 )
858 })
859 })
860 }
861
862 pub fn wait_for_initial_indexing(
863 &mut self,
864 project: &Entity<Project>,
865 cx: &mut App,
866 ) -> Task<Result<()>> {
867 let zeta_project = self.get_or_init_zeta_project(project, cx);
868 zeta_project
869 .syntax_index
870 .read(cx)
871 .wait_for_initial_file_indexing(cx)
872 }
873}
874
/// Returned by `Zeta::perform_request` when the server's minimum required Zed
/// version is newer than the running one.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: SemanticVersion,
}
882
883fn make_cloud_request(
884 excerpt_path: Arc<Path>,
885 context: EditPredictionContext,
886 events: Vec<predict_edits_v3::Event>,
887 can_collect_data: bool,
888 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
889 diagnostic_groups_truncated: bool,
890 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
891 debug_info: bool,
892 worktrees: &Vec<worktree::Snapshot>,
893 index_state: Option<&SyntaxIndexState>,
894 prompt_max_bytes: Option<usize>,
895 prompt_format: PromptFormat,
896) -> predict_edits_v3::PredictEditsRequest {
897 let mut signatures = Vec::new();
898 let mut declaration_to_signature_index = HashMap::default();
899 let mut referenced_declarations = Vec::new();
900
901 for snippet in context.declarations {
902 let project_entry_id = snippet.declaration.project_entry_id();
903 let Some(path) = worktrees.iter().find_map(|worktree| {
904 worktree.entry_for_id(project_entry_id).map(|entry| {
905 let mut full_path = RelPathBuf::new();
906 full_path.push(worktree.root_name());
907 full_path.push(&entry.path);
908 full_path
909 })
910 }) else {
911 continue;
912 };
913
914 let parent_index = index_state.and_then(|index_state| {
915 snippet.declaration.parent().and_then(|parent| {
916 add_signature(
917 parent,
918 &mut declaration_to_signature_index,
919 &mut signatures,
920 index_state,
921 )
922 })
923 });
924
925 let (text, text_is_truncated) = snippet.declaration.item_text();
926 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
927 path: path.as_std_path().into(),
928 text: text.into(),
929 range: snippet.declaration.item_range(),
930 text_is_truncated,
931 signature_range: snippet.declaration.signature_range_in_item_text(),
932 parent_index,
933 signature_score: snippet.score(DeclarationStyle::Signature),
934 declaration_score: snippet.score(DeclarationStyle::Declaration),
935 score_components: snippet.components,
936 });
937 }
938
939 let excerpt_parent = index_state.and_then(|index_state| {
940 context
941 .excerpt
942 .parent_declarations
943 .last()
944 .and_then(|(parent, _)| {
945 add_signature(
946 *parent,
947 &mut declaration_to_signature_index,
948 &mut signatures,
949 index_state,
950 )
951 })
952 });
953
954 predict_edits_v3::PredictEditsRequest {
955 excerpt_path,
956 excerpt: context.excerpt_text.body,
957 excerpt_range: context.excerpt.range,
958 cursor_offset: context.cursor_offset_in_excerpt,
959 referenced_declarations,
960 signatures,
961 excerpt_parent,
962 events,
963 can_collect_data,
964 diagnostic_groups,
965 diagnostic_groups_truncated,
966 git_info,
967 debug_info,
968 prompt_max_bytes,
969 prompt_format,
970 }
971}
972
973fn add_signature(
974 declaration_id: DeclarationId,
975 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
976 signatures: &mut Vec<Signature>,
977 index: &SyntaxIndexState,
978) -> Option<usize> {
979 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
980 return Some(*signature_index);
981 }
982 let Some(parent_declaration) = index.declaration(declaration_id) else {
983 log::error!("bug: missing parent declaration");
984 return None;
985 };
986 let parent_index = parent_declaration.parent().and_then(|parent| {
987 add_signature(parent, declaration_to_signature_index, signatures, index)
988 });
989 let (text, text_is_truncated) = parent_declaration.signature_text();
990 let signature_index = signatures.len();
991 signatures.push(Signature {
992 text: text.into(),
993 text_is_truncated,
994 parent_index,
995 range: parent_declaration.signature_range(),
996 });
997 declaration_to_signature_index.insert(declaration_id, signature_index);
998 Some(signature_index)
999}
1000
1001#[cfg(test)]
1002mod tests {
1003 use std::{
1004 path::{Path, PathBuf},
1005 sync::Arc,
1006 };
1007
1008 use client::UserStore;
1009 use clock::FakeSystemClock;
1010 use cloud_llm_client::predict_edits_v3;
1011 use futures::{
1012 AsyncReadExt, StreamExt,
1013 channel::{mpsc, oneshot},
1014 };
1015 use gpui::{
1016 Entity, TestAppContext,
1017 http_client::{FakeHttpClient, Response},
1018 prelude::*,
1019 };
1020 use indoc::indoc;
1021 use language::{LanguageServerId, OffsetRangeExt as _};
1022 use pretty_assertions::{assert_eq, assert_matches};
1023 use project::{FakeFs, Project};
1024 use serde_json::json;
1025 use settings::SettingsStore;
1026 use util::path;
1027 use uuid::Uuid;
1028
1029 use crate::{BufferEditPrediction, Zeta};
1030
    // Exercises `current_prediction_for_buffer`:
    // - a prediction that edits the requesting buffer is `Local` there;
    // - a prediction that edits another file is a `Jump` from the requesting
    //   buffer and `Local` for the buffer of the file it targets.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Answer the intercepted request with an edit to 1.txt itself.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: 0..snapshot1.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Requested from 1.txt, but the response edits 2.txt.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: 0..snapshot1.len(),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // From the targeted file's own buffer, the same prediction is local.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1129
1130 #[gpui::test]
1131 async fn test_simple_request(cx: &mut TestAppContext) {
1132 let (zeta, mut req_rx) = init_test(cx);
1133 let fs = FakeFs::new(cx.executor());
1134 fs.insert_tree(
1135 "/root",
1136 json!({
1137 "foo.md": "Hello!\nHow\nBye"
1138 }),
1139 )
1140 .await;
1141 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1142
1143 let buffer = project
1144 .update(cx, |project, cx| {
1145 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1146 project.open_buffer(path, cx)
1147 })
1148 .await
1149 .unwrap();
1150 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1151 let position = snapshot.anchor_before(language::Point::new(1, 3));
1152
1153 let prediction_task = zeta.update(cx, |zeta, cx| {
1154 zeta.request_prediction(&project, &buffer, position, cx)
1155 });
1156
1157 let (request, respond_tx) = req_rx.next().await.unwrap();
1158 assert_eq!(
1159 request.excerpt_path.as_ref(),
1160 Path::new(path!("root/foo.md"))
1161 );
1162 assert_eq!(request.cursor_offset, 10);
1163
1164 respond_tx
1165 .send(predict_edits_v3::PredictEditsResponse {
1166 request_id: Uuid::new_v4(),
1167 edits: vec![predict_edits_v3::Edit {
1168 path: Path::new(path!("root/foo.md")).into(),
1169 range: 0..snapshot.len(),
1170 content: "Hello!\nHow are you?\nBye".into(),
1171 }],
1172 debug_info: None,
1173 })
1174 .unwrap();
1175
1176 let prediction = prediction_task.await.unwrap().unwrap();
1177
1178 assert_eq!(prediction.edits.len(), 1);
1179 assert_eq!(
1180 prediction.edits[0].0.to_point(&snapshot).start,
1181 language::Point::new(1, 3)
1182 );
1183 assert_eq!(prediction.edits[0].1, " are you?");
1184 }
1185
1186 #[gpui::test]
1187 async fn test_request_events(cx: &mut TestAppContext) {
1188 let (zeta, mut req_rx) = init_test(cx);
1189 let fs = FakeFs::new(cx.executor());
1190 fs.insert_tree(
1191 "/root",
1192 json!({
1193 "foo.md": "Hello!\n\nBye"
1194 }),
1195 )
1196 .await;
1197 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1198
1199 let buffer = project
1200 .update(cx, |project, cx| {
1201 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1202 project.open_buffer(path, cx)
1203 })
1204 .await
1205 .unwrap();
1206
1207 zeta.update(cx, |zeta, cx| {
1208 zeta.register_buffer(&buffer, &project, cx);
1209 });
1210
1211 buffer.update(cx, |buffer, cx| {
1212 buffer.edit(vec![(7..7, "How")], None, cx);
1213 });
1214
1215 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1216 let position = snapshot.anchor_before(language::Point::new(1, 3));
1217
1218 let prediction_task = zeta.update(cx, |zeta, cx| {
1219 zeta.request_prediction(&project, &buffer, position, cx)
1220 });
1221
1222 let (request, respond_tx) = req_rx.next().await.unwrap();
1223
1224 assert_eq!(request.events.len(), 1);
1225 assert_eq!(
1226 request.events[0],
1227 predict_edits_v3::Event::BufferChange {
1228 path: Some(PathBuf::from(path!("root/foo.md"))),
1229 old_path: None,
1230 diff: indoc! {"
1231 @@ -1,3 +1,3 @@
1232 Hello!
1233 -
1234 +How
1235 Bye
1236 "}
1237 .to_string(),
1238 predicted: false
1239 }
1240 );
1241
1242 respond_tx
1243 .send(predict_edits_v3::PredictEditsResponse {
1244 request_id: Uuid::new_v4(),
1245 edits: vec![predict_edits_v3::Edit {
1246 path: Path::new(path!("root/foo.md")).into(),
1247 range: 0..snapshot.len(),
1248 content: "Hello!\nHow are you?\nBye".into(),
1249 }],
1250 debug_info: None,
1251 })
1252 .unwrap();
1253
1254 let prediction = prediction_task.await.unwrap().unwrap();
1255
1256 assert_eq!(prediction.edits.len(), 1);
1257 assert_eq!(
1258 prediction.edits[0].0.to_point(&snapshot).start,
1259 language::Point::new(1, 3)
1260 );
1261 assert_eq!(prediction.edits[0].1, " are you?");
1262 }
1263
1264 #[gpui::test]
1265 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1266 let (zeta, mut req_rx) = init_test(cx);
1267 let fs = FakeFs::new(cx.executor());
1268 fs.insert_tree(
1269 "/root",
1270 json!({
1271 "foo.md": "Hello!\nBye"
1272 }),
1273 )
1274 .await;
1275 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1276
1277 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1278 let diagnostic = lsp::Diagnostic {
1279 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1280 severity: Some(lsp::DiagnosticSeverity::ERROR),
1281 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1282 ..Default::default()
1283 };
1284
1285 project.update(cx, |project, cx| {
1286 project.lsp_store().update(cx, |lsp_store, cx| {
1287 // Create some diagnostics
1288 lsp_store
1289 .update_diagnostics(
1290 LanguageServerId(0),
1291 lsp::PublishDiagnosticsParams {
1292 uri: path_to_buffer_uri.clone(),
1293 diagnostics: vec![diagnostic],
1294 version: None,
1295 },
1296 None,
1297 language::DiagnosticSourceKind::Pushed,
1298 &[],
1299 cx,
1300 )
1301 .unwrap();
1302 });
1303 });
1304
1305 let buffer = project
1306 .update(cx, |project, cx| {
1307 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1308 project.open_buffer(path, cx)
1309 })
1310 .await
1311 .unwrap();
1312
1313 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1314 let position = snapshot.anchor_before(language::Point::new(0, 0));
1315
1316 let _prediction_task = zeta.update(cx, |zeta, cx| {
1317 zeta.request_prediction(&project, &buffer, position, cx)
1318 });
1319
1320 let (request, _respond_tx) = req_rx.next().await.unwrap();
1321
1322 assert_eq!(request.diagnostic_groups.len(), 1);
1323 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1324 .unwrap();
1325 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1326 assert_eq!(
1327 value,
1328 json!({
1329 "entries": [{
1330 "range": {
1331 "start": 8,
1332 "end": 10
1333 },
1334 "diagnostic": {
1335 "source": null,
1336 "code": null,
1337 "code_description": null,
1338 "severity": 1,
1339 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1340 "markdown": null,
1341 "group_id": 0,
1342 "is_primary": true,
1343 "is_disk_based": false,
1344 "is_unnecessary": false,
1345 "source_kind": "Pushed",
1346 "data": null,
1347 "underline": true
1348 }
1349 }],
1350 "primary_ix": 0
1351 })
1352 );
1353 }
1354
1355 fn init_test(
1356 cx: &mut TestAppContext,
1357 ) -> (
1358 Entity<Zeta>,
1359 mpsc::UnboundedReceiver<(
1360 predict_edits_v3::PredictEditsRequest,
1361 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1362 )>,
1363 ) {
1364 cx.update(move |cx| {
1365 let settings_store = SettingsStore::test(cx);
1366 cx.set_global(settings_store);
1367 language::init(cx);
1368 Project::init_settings(cx);
1369
1370 let (req_tx, req_rx) = mpsc::unbounded();
1371
1372 let http_client = FakeHttpClient::create({
1373 move |req| {
1374 let uri = req.uri().path().to_string();
1375 let mut body = req.into_body();
1376 let req_tx = req_tx.clone();
1377 async move {
1378 let resp = match uri.as_str() {
1379 "/client/llm_tokens" => serde_json::to_string(&json!({
1380 "token": "test"
1381 }))
1382 .unwrap(),
1383 "/predict_edits/v3" => {
1384 let mut buf = Vec::new();
1385 body.read_to_end(&mut buf).await.ok();
1386 let req = serde_json::from_slice(&buf).unwrap();
1387
1388 let (res_tx, res_rx) = oneshot::channel();
1389 req_tx.unbounded_send((req, res_tx)).unwrap();
1390 serde_json::to_string(&res_rx.await?).unwrap()
1391 }
1392 _ => {
1393 panic!("Unexpected path: {}", uri)
1394 }
1395 };
1396
1397 Ok(Response::builder().body(resp.into()).unwrap())
1398 }
1399 }
1400 });
1401
1402 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1403 client.cloud_client().set_credentials(1, "test".into());
1404
1405 language_model::init(client.clone(), cx);
1406
1407 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1408 let zeta = Zeta::global(&client, &user_store, cx);
1409
1410 (zeta, req_rx)
1411 })
1412 }
1413}