1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
9use edit_prediction_context::{
10 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
11 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::{mpsc, oneshot};
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
21use language::{BufferSnapshot, TextBufferSnapshot};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Maximum gap between two consecutive edits for them to be coalesced into a single
/// `Event::BufferChange` by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// Once a project's history reaches this length, `Zeta::push_event` drops the oldest
/// half before appending the next event.
const MAX_EVENT_COUNT: usize = 16;
45
/// Context-gathering options used as the `context` field of [`DEFAULT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
        prefilter_score_ratio: 0.5,
    },
};
58
/// Options a newly constructed [`Zeta`] starts with (see `Zeta::new`); they can be
/// replaced later through `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
66
/// Newtype wrapper under which the shared [`Zeta`] entity is stored as an app global
/// (written by `Zeta::global`, read by `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

// Marker impl that allows `ZetaGlobal` to be used with `set_global` / `try_global`.
impl Global for ZetaGlobal {}
71
/// Edit prediction state shared across projects: credentials, per-project history and
/// the options used to build prediction requests.
pub struct Zeta {
    /// Used to acquire/refresh the LLM token and to obtain the HTTP client that sends
    /// prediction requests.
    client: Arc<Client>,
    /// Source of the edit prediction usage returned by `usage`; updated with the usage
    /// reported alongside each successful response.
    user_store: Entity<UserStore>,
    /// Token sent as the `Authorization: Bearer` header of prediction requests.
    llm_token: LlmApiToken,
    /// Subscription to the `RefreshLlmTokenListener`; its callback refreshes `llm_token`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to subsequent requests (see `set_options`).
    options: ZetaOptions,
    /// Set to `true` once a request has failed with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When set (via `debug_info`), every prediction request also emits a
    /// `PredictionDebugInfo` on this channel.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
82
/// Tunable parameters for building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups attached to a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format forwarded in the request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Passed to `SyntaxIndex::new` when a project's index is first created.
    pub file_indexing_parallelism: usize,
}
91
/// Diagnostic record emitted for each prediction request while a debug receiver is
/// installed (see `Zeta::debug_info`).
pub struct PredictionDebugInfo {
    /// Copy of the context that was gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent in `EditPredictionContext::gather_context`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from the request, or the rendering error as a string.
    pub local_prompt: Result<String, String>,
    /// Resolves with the debug info returned by the server, or with an error message if
    /// the request failed, was skipped, or carried no debug info.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}
100
/// Debug payload carried by a prediction response (`PredictEditsResponse::debug_info`).
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
102
/// State tracked for a single project.
struct ZetaProject {
    /// Syntax index created for the project when it is first seen.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Most recent prediction kept for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
109
/// A prediction together with the buffer it was requested from.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer that `refresh_prediction` was called with.
    pub requested_by_buffer_id: EntityId,
    /// The prediction itself; it may target a different buffer than the requester.
    pub prediction: EditPrediction,
}
115
116impl CurrentEditPrediction {
117 fn should_replace_prediction(
118 &self,
119 old_prediction: &Self,
120 snapshot: &TextBufferSnapshot,
121 ) -> bool {
122 if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
123 return true;
124 }
125
126 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
127 return true;
128 };
129
130 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
131 return false;
132 };
133 if old_edits.len() == 1 && new_edits.len() == 1 {
134 let (old_range, old_text) = &old_edits[0];
135 let (new_range, new_text) = &new_edits[0];
136 new_range == old_range && new_text.starts_with(old_text)
137 } else {
138 true
139 }
140 }
141}
142
/// A prediction from the perspective of a buffer.
///
/// Produced by `Zeta::current_prediction_for_buffer`.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets the buffer it was looked up for.
    Local { prediction: &'a EditPrediction },
    /// The prediction does not target this buffer, but was requested from it.
    Jump { prediction: &'a EditPrediction },
}
149
/// Tracking state for a buffer whose edits are recorded as events.
struct RegisteredBuffer {
    /// Snapshot last observed by `report_changes_for_buffer`; used as the "old" side of
    /// the next `Event::BufferChange`.
    snapshot: BufferSnapshot,
    /// The buffer event subscription and the release observer created in
    /// `register_buffer_impl`.
    _subscriptions: [gpui::Subscription; 2],
}
154
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer contents before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer contents after the change; replaced when later edits are coalesced
        /// into this event.
        new_snapshot: BufferSnapshot,
        /// When the (latest coalesced) change was recorded.
        timestamp: Instant,
    },
}
163
164impl Zeta {
165 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
166 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
167 }
168
169 pub fn global(
170 client: &Arc<Client>,
171 user_store: &Entity<UserStore>,
172 cx: &mut App,
173 ) -> Entity<Self> {
174 cx.try_global::<ZetaGlobal>()
175 .map(|global| global.0.clone())
176 .unwrap_or_else(|| {
177 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
178 cx.set_global(ZetaGlobal(zeta.clone()));
179 zeta
180 })
181 }
182
183 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
184 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
185
186 Self {
187 projects: HashMap::new(),
188 client,
189 user_store,
190 options: DEFAULT_OPTIONS,
191 llm_token: LlmApiToken::default(),
192 _llm_token_subscription: cx.subscribe(
193 &refresh_llm_token_listener,
194 |this, _listener, _event, cx| {
195 let client = this.client.clone();
196 let llm_token = this.llm_token.clone();
197 cx.spawn(async move |_this, _cx| {
198 llm_token.refresh(&client).await?;
199 anyhow::Ok(())
200 })
201 .detach_and_log_err(cx);
202 },
203 ),
204 update_required: false,
205 debug_tx: None,
206 }
207 }
208
209 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
210 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
211 self.debug_tx = Some(debug_watch_tx);
212 debug_watch_rx
213 }
214
/// Returns the options currently used when building prediction requests.
pub fn options(&self) -> &ZetaOptions {
    &self.options
}
218
/// Replaces the options used when building prediction requests.
pub fn set_options(&mut self, options: ZetaOptions) {
    self.options = options;
}
222
223 pub fn clear_history(&mut self) {
224 for zeta_project in self.projects.values_mut() {
225 zeta_project.events.clear();
226 }
227 }
228
229 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
230 self.user_store.read(cx).edit_prediction_usage()
231 }
232
233 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
234 self.get_or_init_zeta_project(project, cx);
235 }
236
237 pub fn register_buffer(
238 &mut self,
239 buffer: &Entity<Buffer>,
240 project: &Entity<Project>,
241 cx: &mut Context<Self>,
242 ) {
243 let zeta_project = self.get_or_init_zeta_project(project, cx);
244 Self::register_buffer_impl(zeta_project, buffer, project, cx);
245 }
246
247 fn get_or_init_zeta_project(
248 &mut self,
249 project: &Entity<Project>,
250 cx: &mut App,
251 ) -> &mut ZetaProject {
252 self.projects
253 .entry(project.entity_id())
254 .or_insert_with(|| ZetaProject {
255 syntax_index: cx.new(|cx| {
256 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
257 }),
258 events: VecDeque::new(),
259 registered_buffers: HashMap::new(),
260 current_prediction: None,
261 })
262 }
263
264 fn register_buffer_impl<'a>(
265 zeta_project: &'a mut ZetaProject,
266 buffer: &Entity<Buffer>,
267 project: &Entity<Project>,
268 cx: &mut Context<Self>,
269 ) -> &'a mut RegisteredBuffer {
270 let buffer_id = buffer.entity_id();
271 match zeta_project.registered_buffers.entry(buffer_id) {
272 hash_map::Entry::Occupied(entry) => entry.into_mut(),
273 hash_map::Entry::Vacant(entry) => {
274 let snapshot = buffer.read(cx).snapshot();
275 let project_entity_id = project.entity_id();
276 entry.insert(RegisteredBuffer {
277 snapshot,
278 _subscriptions: [
279 cx.subscribe(buffer, {
280 let project = project.downgrade();
281 move |this, buffer, event, cx| {
282 if let language::BufferEvent::Edited = event
283 && let Some(project) = project.upgrade()
284 {
285 this.report_changes_for_buffer(&buffer, &project, cx);
286 }
287 }
288 }),
289 cx.observe_release(buffer, move |this, _buffer, _cx| {
290 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
291 else {
292 return;
293 };
294 zeta_project.registered_buffers.remove(&buffer_id);
295 }),
296 ],
297 })
298 }
299 }
300 }
301
302 fn report_changes_for_buffer(
303 &mut self,
304 buffer: &Entity<Buffer>,
305 project: &Entity<Project>,
306 cx: &mut Context<Self>,
307 ) -> BufferSnapshot {
308 let zeta_project = self.get_or_init_zeta_project(project, cx);
309 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
310
311 let new_snapshot = buffer.read(cx).snapshot();
312 if new_snapshot.version != registered_buffer.snapshot.version {
313 let old_snapshot =
314 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
315 Self::push_event(
316 zeta_project,
317 Event::BufferChange {
318 old_snapshot,
319 new_snapshot: new_snapshot.clone(),
320 timestamp: Instant::now(),
321 },
322 );
323 }
324
325 new_snapshot
326 }
327
328 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
329 let events = &mut zeta_project.events;
330
331 if let Some(Event::BufferChange {
332 new_snapshot: last_new_snapshot,
333 timestamp: last_timestamp,
334 ..
335 }) = events.back_mut()
336 {
337 // Coalesce edits for the same buffer when they happen one after the other.
338 let Event::BufferChange {
339 old_snapshot,
340 new_snapshot,
341 timestamp,
342 } = &event;
343
344 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
345 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
346 && old_snapshot.version == last_new_snapshot.version
347 {
348 *last_new_snapshot = new_snapshot.clone();
349 *last_timestamp = *timestamp;
350 return;
351 }
352 }
353
354 if events.len() >= MAX_EVENT_COUNT {
355 // These are halved instead of popping to improve prompt caching.
356 events.drain(..MAX_EVENT_COUNT / 2);
357 }
358
359 events.push_back(event);
360 }
361
362 fn current_prediction_for_buffer(
363 &self,
364 buffer: &Entity<Buffer>,
365 project: &Entity<Project>,
366 cx: &App,
367 ) -> Option<BufferEditPrediction<'_>> {
368 let project_state = self.projects.get(&project.entity_id())?;
369
370 let CurrentEditPrediction {
371 requested_by_buffer_id,
372 prediction,
373 } = project_state.current_prediction.as_ref()?;
374
375 if prediction.targets_buffer(buffer.read(cx), cx) {
376 Some(BufferEditPrediction::Local { prediction })
377 } else if *requested_by_buffer_id == buffer.entity_id() {
378 Some(BufferEditPrediction::Jump { prediction })
379 } else {
380 None
381 }
382 }
383
384 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
385 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
386 project_state.current_prediction.take();
387 };
388 // TODO report accepted
389 }
390
391 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
392 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
393 project_state.current_prediction.take();
394 };
395 }
396
397 pub fn refresh_prediction(
398 &mut self,
399 project: &Entity<Project>,
400 buffer: &Entity<Buffer>,
401 position: language::Anchor,
402 cx: &mut Context<Self>,
403 ) -> Task<Result<()>> {
404 let request_task = self.request_prediction(project, buffer, position, cx);
405 let buffer = buffer.clone();
406 let project = project.clone();
407
408 cx.spawn(async move |this, cx| {
409 if let Some(prediction) = request_task.await? {
410 this.update(cx, |this, cx| {
411 let project_state = this
412 .projects
413 .get_mut(&project.entity_id())
414 .context("Project not found")?;
415
416 let new_prediction = CurrentEditPrediction {
417 requested_by_buffer_id: buffer.entity_id(),
418 prediction: prediction,
419 };
420
421 if project_state
422 .current_prediction
423 .as_ref()
424 .is_none_or(|old_prediction| {
425 new_prediction
426 .should_replace_prediction(&old_prediction, buffer.read(cx))
427 })
428 {
429 project_state.current_prediction = Some(new_prediction);
430 }
431 anyhow::Ok(())
432 })??;
433 }
434 Ok(())
435 })
436 }
437
/// Builds and sends an edit prediction request for `position` in `buffer`.
///
/// Everything that needs the app context (buffer snapshot, worktree snapshots, event
/// history, diagnostics, paths) is captured synchronously. Context gathering and the
/// HTTP request then run in a task created with `cx.background_spawn`, and the
/// response is converted with `EditPrediction::from_response` in a task created with
/// `cx.spawn`.
///
/// The returned task:
/// - fails immediately if the buffer has no file ("No file path for excerpt");
/// - resolves to `Ok(None)` if `EditPredictionContext::gather_context` yields nothing;
/// - otherwise resolves to whatever `EditPrediction::from_response` produces, after
///   forwarding any reported usage to the user store;
/// - on a `ZedUpdateRequiredError`, sets `update_required`, shows a notification that
///   links to the releases page, and returns the error.
///
/// When a debug sender is installed, a `PredictionDebugInfo` is emitted per request
/// and its `response_rx` is resolved once the request finishes or is skipped.
fn request_prediction(
    &mut self,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    position: language::Anchor,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
    // The project may be unknown here; in that case no index state and no events are used.
    let project_state = self.projects.get(&project.entity_id());

    let index_state = project_state.map(|state| {
        state
            .syntax_index
            .read_with(cx, |index, _cx| index.state().clone())
    });
    let options = self.options.clone();
    let snapshot = buffer.read(cx).snapshot();
    let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let worktree_snapshots = project
        .read(cx)
        .worktrees(cx)
        .map(|worktree| worktree.read(cx).snapshot())
        .collect::<Vec<_>>();
    let debug_tx = self.debug_tx.clone();

    // Convert the recorded history into request events, one unified diff per change.
    let events = project_state
        .map(|state| {
            state
                .events
                .iter()
                .filter_map(|event| match event {
                    Event::BufferChange {
                        old_snapshot,
                        new_snapshot,
                        ..
                    } => {
                        let path = new_snapshot.file().map(|f| f.full_path(cx));

                        // `old_path` is only reported when it differs from the new path.
                        let old_path = old_snapshot.file().and_then(|f| {
                            let old_path = f.full_path(cx);
                            if Some(&old_path) != path.as_ref() {
                                Some(old_path)
                            } else {
                                None
                            }
                        });

                        // TODO [zeta2] move to bg?
                        let diff =
                            language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                        if path == old_path && diff.is_empty() {
                            None
                        } else {
                            Some(predict_edits_v3::Event::BufferChange {
                                old_path,
                                path,
                                diff,
                                //todo: Actually detect if this edit was predicted or not
                                predicted: false,
                            })
                        }
                    }
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();

    let diagnostics = snapshot.diagnostic_sets().clone();

    // Absolute path of the directory containing the buffer's file, when it has one.
    let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
        let mut path = f.worktree.read(cx).absolutize(&f.path);
        if path.pop() { Some(path) } else { None }
    });

    let request_task = cx.background_spawn({
        let snapshot = snapshot.clone();
        let buffer = buffer.clone();
        async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_offset = position.to_offset(&snapshot);
            let cursor_point = cursor_offset.to_point(&snapshot);

            // Timestamps around `gather_context` feed `retrieval_time` in the debug info.
            let before_retrieval = chrono::Utc::now();

            let Some(context) = EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                &options.context,
                index_state.as_deref(),
            ) else {
                return Ok(None);
            };

            let retrieval_time = chrono::Utc::now() - before_retrieval;

            let (diagnostic_groups, diagnostic_groups_truncated) =
                Self::gather_nearby_diagnostics(
                    cursor_offset,
                    &diagnostics,
                    &snapshot,
                    options.max_diagnostic_bytes,
                );

            // Keep a copy of the context only when somebody is listening for debug info.
            let debug_context = debug_tx.map(|tx| (tx, context.clone()));

            let request = make_cloud_request(
                excerpt_path,
                context,
                events,
                // TODO data collection
                false,
                diagnostic_groups,
                diagnostic_groups_truncated,
                None,
                debug_context.is_some(),
                &worktree_snapshots,
                index_state.as_deref(),
                Some(options.max_prompt_bytes),
                options.prompt_format,
            );

            let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
                let (response_tx, response_rx) = oneshot::channel();

                let local_prompt = PlannedPrompt::populate(&request)
                    .and_then(|p| p.to_prompt_string().map(|p| p.0))
                    .map_err(|err| err.to_string());

                // A send failure is ignored via `.ok()`.
                debug_tx
                    .unbounded_send(PredictionDebugInfo {
                        context,
                        retrieval_time,
                        buffer: buffer.downgrade(),
                        local_prompt,
                        position,
                        response_rx,
                    })
                    .ok();
                Some(response_tx)
            } else {
                None
            };

            // Debug builds can skip the network request via an environment variable.
            if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send(Err("Request skipped".to_string()))
                        .ok();
                }
                anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
            }

            let response = Self::perform_request(client, llm_token, app_version, request).await;

            // Report the server's debug info (or the failure) to the debug listener.
            if let Some(debug_response_tx) = debug_response_tx {
                debug_response_tx
                    .send(response.as_ref().map_err(|err| err.to_string()).and_then(
                        |response| match some_or_debug_panic(response.0.debug_info.clone()) {
                            Some(debug_info) => Ok(debug_info),
                            None => Err("Missing debug info".to_string()),
                        },
                    ))
                    .ok();
            }

            anyhow::Ok(Some(response?))
        }
    });

    let buffer = buffer.clone();

    cx.spawn({
        let project = project.clone();
        async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    let prediction = EditPrediction::from_response(
                        response, &snapshot, &buffer, &project, cx,
                    )
                    .await;

                    // TODO telemetry: duration, etc
                    Ok(prediction)
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    // An outdated client is surfaced to the user in addition to failing.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        }
    })
}
675
676 async fn perform_request(
677 client: Arc<Client>,
678 llm_token: LlmApiToken,
679 app_version: SemanticVersion,
680 request: predict_edits_v3::PredictEditsRequest,
681 ) -> Result<(
682 predict_edits_v3::PredictEditsResponse,
683 Option<EditPredictionUsage>,
684 )> {
685 let http_client = client.http_client();
686 let mut token = llm_token.acquire(&client).await?;
687 let mut did_retry = false;
688
689 loop {
690 let request_builder = http_client::Request::builder().method(Method::POST);
691 let request_builder =
692 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
693 request_builder.uri(predict_edits_url)
694 } else {
695 request_builder.uri(
696 http_client
697 .build_zed_llm_url("/predict_edits/v3", &[])?
698 .as_ref(),
699 )
700 };
701 let request = request_builder
702 .header("Content-Type", "application/json")
703 .header("Authorization", format!("Bearer {}", token))
704 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
705 .body(serde_json::to_string(&request)?.into())?;
706
707 let mut response = http_client.send(request).await?;
708
709 if let Some(minimum_required_version) = response
710 .headers()
711 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
712 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
713 {
714 anyhow::ensure!(
715 app_version >= minimum_required_version,
716 ZedUpdateRequiredError {
717 minimum_version: minimum_required_version
718 }
719 );
720 }
721
722 if response.status().is_success() {
723 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
724
725 let mut body = Vec::new();
726 response.body_mut().read_to_end(&mut body).await?;
727 return Ok((serde_json::from_slice(&body)?, usage));
728 } else if !did_retry
729 && response
730 .headers()
731 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
732 .is_some()
733 {
734 did_retry = true;
735 token = llm_token.refresh(&client).await?;
736 } else {
737 let mut body = String::new();
738 response.body_mut().read_to_string(&mut body).await?;
739 anyhow::bail!(
740 "error predicting edits.\nStatus: {:?}\nBody: {}",
741 response.status(),
742 body
743 );
744 }
745 }
746 }
747
748 fn gather_nearby_diagnostics(
749 cursor_offset: usize,
750 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
751 snapshot: &BufferSnapshot,
752 max_diagnostics_bytes: usize,
753 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
754 // TODO: Could make this more efficient
755 let mut diagnostic_groups = Vec::new();
756 for (language_server_id, diagnostics) in diagnostic_sets {
757 let mut groups = Vec::new();
758 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
759 diagnostic_groups.extend(
760 groups
761 .into_iter()
762 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
763 );
764 }
765
766 // sort by proximity to cursor
767 diagnostic_groups.sort_by_key(|group| {
768 let range = &group.entries[group.primary_ix].range;
769 if range.start >= cursor_offset {
770 range.start - cursor_offset
771 } else if cursor_offset >= range.end {
772 cursor_offset - range.end
773 } else {
774 (cursor_offset - range.start).min(range.end - cursor_offset)
775 }
776 });
777
778 let mut results = Vec::new();
779 let mut diagnostic_groups_truncated = false;
780 let mut diagnostics_byte_count = 0;
781 for group in diagnostic_groups {
782 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
783 diagnostics_byte_count += raw_value.get().len();
784 if diagnostics_byte_count > max_diagnostics_bytes {
785 diagnostic_groups_truncated = true;
786 break;
787 }
788 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
789 }
790
791 (results, diagnostic_groups_truncated)
792 }
793
794 // TODO: Dedupe with similar code in request_prediction?
795 pub fn cloud_request_for_zeta_cli(
796 &mut self,
797 project: &Entity<Project>,
798 buffer: &Entity<Buffer>,
799 position: language::Anchor,
800 cx: &mut Context<Self>,
801 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
802 let project_state = self.projects.get(&project.entity_id());
803
804 let index_state = project_state.map(|state| {
805 state
806 .syntax_index
807 .read_with(cx, |index, _cx| index.state().clone())
808 });
809 let options = self.options.clone();
810 let snapshot = buffer.read(cx).snapshot();
811 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
812 return Task::ready(Err(anyhow!("No file path for excerpt")));
813 };
814 let worktree_snapshots = project
815 .read(cx)
816 .worktrees(cx)
817 .map(|worktree| worktree.read(cx).snapshot())
818 .collect::<Vec<_>>();
819
820 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
821 let mut path = f.worktree.read(cx).absolutize(&f.path);
822 if path.pop() { Some(path) } else { None }
823 });
824
825 cx.background_spawn(async move {
826 let index_state = if let Some(index_state) = index_state {
827 Some(index_state.lock_owned().await)
828 } else {
829 None
830 };
831
832 let cursor_point = position.to_point(&snapshot);
833
834 let debug_info = true;
835 EditPredictionContext::gather_context(
836 cursor_point,
837 &snapshot,
838 parent_abs_path.as_deref(),
839 &options.context,
840 index_state.as_deref(),
841 )
842 .context("Failed to select excerpt")
843 .map(|context| {
844 make_cloud_request(
845 excerpt_path.into(),
846 context,
847 // TODO pass everything
848 Vec::new(),
849 false,
850 Vec::new(),
851 false,
852 None,
853 debug_info,
854 &worktree_snapshots,
855 index_state.as_deref(),
856 Some(options.max_prompt_bytes),
857 options.prompt_format,
858 )
859 })
860 })
861 }
862
863 pub fn wait_for_initial_indexing(
864 &mut self,
865 project: &Entity<Project>,
866 cx: &mut App,
867 ) -> Task<Result<()>> {
868 let zeta_project = self.get_or_init_zeta_project(project, cx);
869 zeta_project
870 .syntax_index
871 .read(cx)
872 .wait_for_initial_file_indexing(cx)
873 }
874}
875
/// Error returned by `Zeta::perform_request` when the server's minimum required
/// version header names a version newer than the running app.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised by the server.
    minimum_version: SemanticVersion,
}
883
884fn make_cloud_request(
885 excerpt_path: Arc<Path>,
886 context: EditPredictionContext,
887 events: Vec<predict_edits_v3::Event>,
888 can_collect_data: bool,
889 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
890 diagnostic_groups_truncated: bool,
891 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
892 debug_info: bool,
893 worktrees: &Vec<worktree::Snapshot>,
894 index_state: Option<&SyntaxIndexState>,
895 prompt_max_bytes: Option<usize>,
896 prompt_format: PromptFormat,
897) -> predict_edits_v3::PredictEditsRequest {
898 let mut signatures = Vec::new();
899 let mut declaration_to_signature_index = HashMap::default();
900 let mut referenced_declarations = Vec::new();
901
902 for snippet in context.declarations {
903 let project_entry_id = snippet.declaration.project_entry_id();
904 let Some(path) = worktrees.iter().find_map(|worktree| {
905 worktree.entry_for_id(project_entry_id).map(|entry| {
906 let mut full_path = RelPathBuf::new();
907 full_path.push(worktree.root_name());
908 full_path.push(&entry.path);
909 full_path
910 })
911 }) else {
912 continue;
913 };
914
915 let parent_index = index_state.and_then(|index_state| {
916 snippet.declaration.parent().and_then(|parent| {
917 add_signature(
918 parent,
919 &mut declaration_to_signature_index,
920 &mut signatures,
921 index_state,
922 )
923 })
924 });
925
926 let (text, text_is_truncated) = snippet.declaration.item_text();
927 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
928 path: path.as_std_path().into(),
929 text: text.into(),
930 range: snippet.declaration.item_range(),
931 text_is_truncated,
932 signature_range: snippet.declaration.signature_range_in_item_text(),
933 parent_index,
934 signature_score: snippet.score(DeclarationStyle::Signature),
935 declaration_score: snippet.score(DeclarationStyle::Declaration),
936 score_components: snippet.components,
937 });
938 }
939
940 let excerpt_parent = index_state.and_then(|index_state| {
941 context
942 .excerpt
943 .parent_declarations
944 .last()
945 .and_then(|(parent, _)| {
946 add_signature(
947 *parent,
948 &mut declaration_to_signature_index,
949 &mut signatures,
950 index_state,
951 )
952 })
953 });
954
955 predict_edits_v3::PredictEditsRequest {
956 excerpt_path,
957 excerpt: context.excerpt_text.body,
958 excerpt_range: context.excerpt.range,
959 cursor_offset: context.cursor_offset_in_excerpt,
960 referenced_declarations,
961 signatures,
962 excerpt_parent,
963 events,
964 can_collect_data,
965 diagnostic_groups,
966 diagnostic_groups_truncated,
967 git_info,
968 debug_info,
969 prompt_max_bytes,
970 prompt_format,
971 }
972}
973
974fn add_signature(
975 declaration_id: DeclarationId,
976 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
977 signatures: &mut Vec<Signature>,
978 index: &SyntaxIndexState,
979) -> Option<usize> {
980 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
981 return Some(*signature_index);
982 }
983 let Some(parent_declaration) = index.declaration(declaration_id) else {
984 log::error!("bug: missing parent declaration");
985 return None;
986 };
987 let parent_index = parent_declaration.parent().and_then(|parent| {
988 add_signature(parent, declaration_to_signature_index, signatures, index)
989 });
990 let (text, text_is_truncated) = parent_declaration.signature_text();
991 let signature_index = signatures.len();
992 signatures.push(Signature {
993 text: text.into(),
994 text_is_truncated,
995 parent_index,
996 range: parent_declaration.signature_range(),
997 });
998 declaration_to_signature_index.insert(declaration_id, signature_index);
999 Some(signature_index)
1000}
1001
1002#[cfg(test)]
1003mod tests {
1004 use std::{
1005 path::{Path, PathBuf},
1006 sync::Arc,
1007 };
1008
1009 use client::UserStore;
1010 use clock::FakeSystemClock;
1011 use cloud_llm_client::predict_edits_v3;
1012 use futures::{
1013 AsyncReadExt, StreamExt,
1014 channel::{mpsc, oneshot},
1015 };
1016 use gpui::{
1017 Entity, TestAppContext,
1018 http_client::{FakeHttpClient, Response},
1019 prelude::*,
1020 };
1021 use indoc::indoc;
1022 use language::{LanguageServerId, OffsetRangeExt as _};
1023 use pretty_assertions::{assert_eq, assert_matches};
1024 use project::{FakeFs, Project};
1025 use serde_json::json;
1026 use settings::SettingsStore;
1027 use util::path;
1028 use uuid::Uuid;
1029
1030 use crate::{BufferEditPrediction, Zeta};
1031
// Checks how the current prediction is classified per buffer: `Local` for the buffer
// it edits, `Jump` for the requesting buffer when the edit targets another file.
#[gpui::test]
async fn test_current_state(cx: &mut TestAppContext) {
    // `req_rx` yields each outgoing request together with a sender for its response.
    let (zeta, mut req_rx) = init_test(cx);
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "1.txt": "Hello!\nHow\nBye",
            "2.txt": "Hola!\nComo\nAdios"
        }),
    )
    .await;
    let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    zeta.update(cx, |zeta, cx| {
        zeta.register_project(&project, cx);
    });

    let buffer1 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();
    let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
    let position = snapshot1.anchor_before(language::Point::new(1, 3));

    // Prediction for current file

    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    respond_tx
        .send(predict_edits_v3::PredictEditsResponse {
            request_id: Uuid::new_v4(),
            edits: vec![predict_edits_v3::Edit {
                path: Path::new(path!("root/1.txt")).into(),
                range: 0..snapshot1.len(),
                content: "Hello!\nHow are you?\nBye".into(),
            }],
            debug_info: None,
        })
        .unwrap();
    prediction_task.await.unwrap();

    // The edit targets the requesting buffer, so it is reported as `Local`.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });

    // Prediction for another file

    let prediction_task = zeta.update(cx, |zeta, cx| {
        zeta.refresh_prediction(&project, &buffer1, position, cx)
    });
    let (_request, respond_tx) = req_rx.next().await.unwrap();
    respond_tx
        .send(predict_edits_v3::PredictEditsResponse {
            request_id: Uuid::new_v4(),
            edits: vec![predict_edits_v3::Edit {
                path: Path::new(path!("root/2.txt")).into(),
                range: 0..snapshot1.len(),
                content: "Hola!\nComo estas?\nAdios".into(),
            }],
            debug_info: None,
        })
        .unwrap();
    prediction_task.await.unwrap();

    // Requested from buffer1 but targeting 2.txt: buffer1 sees it as a `Jump`.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer1, &project, cx)
            .unwrap();
        assert_matches!(
            prediction,
            BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
        );
    });

    let buffer2 = project
        .update(cx, |project, cx| {
            let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
            project.open_buffer(path, cx)
        })
        .await
        .unwrap();

    // From the perspective of the targeted buffer the same prediction is `Local`.
    zeta.read_with(cx, |zeta, cx| {
        let prediction = zeta
            .current_prediction_for_buffer(&buffer2, &project, cx)
            .unwrap();
        assert_matches!(prediction, BufferEditPrediction::Local { .. });
    });
}
1130
1131 #[gpui::test]
1132 async fn test_simple_request(cx: &mut TestAppContext) {
1133 let (zeta, mut req_rx) = init_test(cx);
1134 let fs = FakeFs::new(cx.executor());
1135 fs.insert_tree(
1136 "/root",
1137 json!({
1138 "foo.md": "Hello!\nHow\nBye"
1139 }),
1140 )
1141 .await;
1142 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1143
1144 let buffer = project
1145 .update(cx, |project, cx| {
1146 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1147 project.open_buffer(path, cx)
1148 })
1149 .await
1150 .unwrap();
1151 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1152 let position = snapshot.anchor_before(language::Point::new(1, 3));
1153
1154 let prediction_task = zeta.update(cx, |zeta, cx| {
1155 zeta.request_prediction(&project, &buffer, position, cx)
1156 });
1157
1158 let (request, respond_tx) = req_rx.next().await.unwrap();
1159 assert_eq!(
1160 request.excerpt_path.as_ref(),
1161 Path::new(path!("root/foo.md"))
1162 );
1163 assert_eq!(request.cursor_offset, 10);
1164
1165 respond_tx
1166 .send(predict_edits_v3::PredictEditsResponse {
1167 request_id: Uuid::new_v4(),
1168 edits: vec![predict_edits_v3::Edit {
1169 path: Path::new(path!("root/foo.md")).into(),
1170 range: 0..snapshot.len(),
1171 content: "Hello!\nHow are you?\nBye".into(),
1172 }],
1173 debug_info: None,
1174 })
1175 .unwrap();
1176
1177 let prediction = prediction_task.await.unwrap().unwrap();
1178
1179 assert_eq!(prediction.edits.len(), 1);
1180 assert_eq!(
1181 prediction.edits[0].0.to_point(&snapshot).start,
1182 language::Point::new(1, 3)
1183 );
1184 assert_eq!(prediction.edits[0].1, " are you?");
1185 }
1186
1187 #[gpui::test]
1188 async fn test_request_events(cx: &mut TestAppContext) {
1189 let (zeta, mut req_rx) = init_test(cx);
1190 let fs = FakeFs::new(cx.executor());
1191 fs.insert_tree(
1192 "/root",
1193 json!({
1194 "foo.md": "Hello!\n\nBye"
1195 }),
1196 )
1197 .await;
1198 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1199
1200 let buffer = project
1201 .update(cx, |project, cx| {
1202 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1203 project.open_buffer(path, cx)
1204 })
1205 .await
1206 .unwrap();
1207
1208 zeta.update(cx, |zeta, cx| {
1209 zeta.register_buffer(&buffer, &project, cx);
1210 });
1211
1212 buffer.update(cx, |buffer, cx| {
1213 buffer.edit(vec![(7..7, "How")], None, cx);
1214 });
1215
1216 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1217 let position = snapshot.anchor_before(language::Point::new(1, 3));
1218
1219 let prediction_task = zeta.update(cx, |zeta, cx| {
1220 zeta.request_prediction(&project, &buffer, position, cx)
1221 });
1222
1223 let (request, respond_tx) = req_rx.next().await.unwrap();
1224
1225 assert_eq!(request.events.len(), 1);
1226 assert_eq!(
1227 request.events[0],
1228 predict_edits_v3::Event::BufferChange {
1229 path: Some(PathBuf::from(path!("root/foo.md"))),
1230 old_path: None,
1231 diff: indoc! {"
1232 @@ -1,3 +1,3 @@
1233 Hello!
1234 -
1235 +How
1236 Bye
1237 "}
1238 .to_string(),
1239 predicted: false
1240 }
1241 );
1242
1243 respond_tx
1244 .send(predict_edits_v3::PredictEditsResponse {
1245 request_id: Uuid::new_v4(),
1246 edits: vec![predict_edits_v3::Edit {
1247 path: Path::new(path!("root/foo.md")).into(),
1248 range: 0..snapshot.len(),
1249 content: "Hello!\nHow are you?\nBye".into(),
1250 }],
1251 debug_info: None,
1252 })
1253 .unwrap();
1254
1255 let prediction = prediction_task.await.unwrap().unwrap();
1256
1257 assert_eq!(prediction.edits.len(), 1);
1258 assert_eq!(
1259 prediction.edits[0].0.to_point(&snapshot).start,
1260 language::Point::new(1, 3)
1261 );
1262 assert_eq!(prediction.edits[0].1, " are you?");
1263 }
1264
1265 #[gpui::test]
1266 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1267 let (zeta, mut req_rx) = init_test(cx);
1268 let fs = FakeFs::new(cx.executor());
1269 fs.insert_tree(
1270 "/root",
1271 json!({
1272 "foo.md": "Hello!\nBye"
1273 }),
1274 )
1275 .await;
1276 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1277
1278 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1279 let diagnostic = lsp::Diagnostic {
1280 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1281 severity: Some(lsp::DiagnosticSeverity::ERROR),
1282 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1283 ..Default::default()
1284 };
1285
1286 project.update(cx, |project, cx| {
1287 project.lsp_store().update(cx, |lsp_store, cx| {
1288 // Create some diagnostics
1289 lsp_store
1290 .update_diagnostics(
1291 LanguageServerId(0),
1292 lsp::PublishDiagnosticsParams {
1293 uri: path_to_buffer_uri.clone(),
1294 diagnostics: vec![diagnostic],
1295 version: None,
1296 },
1297 None,
1298 language::DiagnosticSourceKind::Pushed,
1299 &[],
1300 cx,
1301 )
1302 .unwrap();
1303 });
1304 });
1305
1306 let buffer = project
1307 .update(cx, |project, cx| {
1308 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1309 project.open_buffer(path, cx)
1310 })
1311 .await
1312 .unwrap();
1313
1314 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1315 let position = snapshot.anchor_before(language::Point::new(0, 0));
1316
1317 let _prediction_task = zeta.update(cx, |zeta, cx| {
1318 zeta.request_prediction(&project, &buffer, position, cx)
1319 });
1320
1321 let (request, _respond_tx) = req_rx.next().await.unwrap();
1322
1323 assert_eq!(request.diagnostic_groups.len(), 1);
1324 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1325 .unwrap();
1326 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1327 assert_eq!(
1328 value,
1329 json!({
1330 "entries": [{
1331 "range": {
1332 "start": 8,
1333 "end": 10
1334 },
1335 "diagnostic": {
1336 "source": null,
1337 "code": null,
1338 "code_description": null,
1339 "severity": 1,
1340 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1341 "markdown": null,
1342 "group_id": 0,
1343 "is_primary": true,
1344 "is_disk_based": false,
1345 "is_unnecessary": false,
1346 "source_kind": "Pushed",
1347 "data": null,
1348 "underline": true
1349 }
1350 }],
1351 "primary_ix": 0
1352 })
1353 );
1354 }
1355
1356 fn init_test(
1357 cx: &mut TestAppContext,
1358 ) -> (
1359 Entity<Zeta>,
1360 mpsc::UnboundedReceiver<(
1361 predict_edits_v3::PredictEditsRequest,
1362 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1363 )>,
1364 ) {
1365 cx.update(move |cx| {
1366 let settings_store = SettingsStore::test(cx);
1367 cx.set_global(settings_store);
1368 language::init(cx);
1369 Project::init_settings(cx);
1370
1371 let (req_tx, req_rx) = mpsc::unbounded();
1372
1373 let http_client = FakeHttpClient::create({
1374 move |req| {
1375 let uri = req.uri().path().to_string();
1376 let mut body = req.into_body();
1377 let req_tx = req_tx.clone();
1378 async move {
1379 let resp = match uri.as_str() {
1380 "/client/llm_tokens" => serde_json::to_string(&json!({
1381 "token": "test"
1382 }))
1383 .unwrap(),
1384 "/predict_edits/v3" => {
1385 let mut buf = Vec::new();
1386 body.read_to_end(&mut buf).await.ok();
1387 let req = serde_json::from_slice(&buf).unwrap();
1388
1389 let (res_tx, res_rx) = oneshot::channel();
1390 req_tx.unbounded_send((req, res_tx)).unwrap();
1391 serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1392 }
1393 _ => {
1394 panic!("Unexpected path: {}", uri)
1395 }
1396 };
1397
1398 Ok(Response::builder().body(resp.into()).unwrap())
1399 }
1400 }
1401 });
1402
1403 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1404 client.cloud_client().set_credentials(1, "test".into());
1405
1406 language_model::init(client.clone(), cx);
1407
1408 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1409 let zeta = Zeta::global(&client, &user_store, cx);
1410
1411 (zeta, req_rx)
1412 })
1413 }
1414}