1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
9use edit_prediction_context::{
10 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
11 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::{mpsc, oneshot};
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Maximum gap between two consecutive changes for them to be merged into a
/// single `Event::BufferChange` by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// Once this many events are stored, `Zeta::push_event` discards the oldest
/// half before appending a new one.
const MAX_EVENT_COUNT: usize = 16;
45
/// Default options for gathering edit prediction context; used as the
/// `context` field of [`DEFAULT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        // NOTE(review): presumably the desired fraction of the excerpt that lies
        // before the cursor — confirm against `EditPredictionExcerptOptions`.
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
};
57
/// Options a freshly constructed [`Zeta`] starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
65
/// Newtype wrapper that lets the shared [`Zeta`] entity be stored as a GPUI
/// global (see `Zeta::global` / `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
70
/// Edit prediction state shared across projects: tracks per-project buffer
/// history and the current prediction, and issues prediction requests.
pub struct Zeta {
    /// Client used to obtain the HTTP client and to acquire/refresh the LLM token.
    client: Arc<Client>,
    /// Store that receives edit prediction usage reported by the server.
    user_store: Entity<UserStore>,
    /// Token sent as the `Authorization: Bearer` header of prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the subscription alive that refreshes `llm_token` whenever the
    /// `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to subsequent requests.
    options: ZetaOptions,
    /// Set to `true` once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When set (via `debug_info`), every request also emits a `PredictionDebugInfo`.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
81
/// Tunable settings for context gathering and prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups attached to a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format sent with the request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism handed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
90
/// Debug record emitted for each prediction request while a debug receiver
/// (see `Zeta::debug_info`) is installed.
pub struct PredictionDebugInfo {
    /// Context that was gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent gathering `context`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from the request, or the error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server's debug info, or an error message, once the request finishes.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}
99
/// Debug information returned by the prediction endpoint.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
101
/// State tracked for a single project.
struct ZetaProject {
    /// Syntax index created for the project when it was first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer change events, oldest first (bounded by `MAX_EVENT_COUNT`).
    events: VecDeque<Event>,
    /// Buffers whose edits are being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Most recent prediction that has not been accepted or discarded.
    current_prediction: Option<CurrentEditPrediction>,
}
108
/// A prediction together with the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer that triggered the request; the prediction
    /// itself may target a different buffer.
    pub requested_by_buffer_id: EntityId,
    /// The prediction returned for that request.
    pub prediction: EditPrediction,
}
114
115impl CurrentEditPrediction {
116 fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
117 let Some(new_edits) = self
118 .prediction
119 .interpolate(&self.prediction.buffer.read(cx))
120 else {
121 return false;
122 };
123
124 if self.prediction.buffer != old_prediction.prediction.buffer {
125 return true;
126 }
127
128 let Some(old_edits) = old_prediction
129 .prediction
130 .interpolate(&old_prediction.prediction.buffer.read(cx))
131 else {
132 return true;
133 };
134
135 // This reduces the occurrence of UI thrash from replacing edits
136 //
137 // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
138 if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
139 && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
140 && old_edits.len() == 1
141 && new_edits.len() == 1
142 {
143 let (old_range, old_text) = &old_edits[0];
144 let (new_range, new_text) = &new_edits[0];
145 new_range == old_range && new_text.starts_with(old_text)
146 } else {
147 true
148 }
149 }
150}
151
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets the buffer it is being viewed from.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets another one.
    Jump { prediction: &'a EditPrediction },
}
158
/// Bookkeeping for a buffer whose edits are being tracked.
struct RegisteredBuffer {
    /// Snapshot taken the last time changes were reported for this buffer.
    snapshot: BufferSnapshot,
    /// Keeps alive the buffer event subscription and the release observer
    /// created in `Zeta::register_buffer_impl`.
    _subscriptions: [gpui::Subscription; 2],
}
163
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed between two snapshots.
    BufferChange {
        /// Snapshot before the change.
        old_snapshot: BufferSnapshot,
        /// Snapshot after the change (replaced when later edits are coalesced into this event).
        new_snapshot: BufferSnapshot,
        /// When the most recent change folded into this event was recorded.
        timestamp: Instant,
    },
}
172
173impl Zeta {
174 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
175 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
176 }
177
178 pub fn global(
179 client: &Arc<Client>,
180 user_store: &Entity<UserStore>,
181 cx: &mut App,
182 ) -> Entity<Self> {
183 cx.try_global::<ZetaGlobal>()
184 .map(|global| global.0.clone())
185 .unwrap_or_else(|| {
186 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
187 cx.set_global(ZetaGlobal(zeta.clone()));
188 zeta
189 })
190 }
191
192 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
193 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
194
195 Self {
196 projects: HashMap::new(),
197 client,
198 user_store,
199 options: DEFAULT_OPTIONS,
200 llm_token: LlmApiToken::default(),
201 _llm_token_subscription: cx.subscribe(
202 &refresh_llm_token_listener,
203 |this, _listener, _event, cx| {
204 let client = this.client.clone();
205 let llm_token = this.llm_token.clone();
206 cx.spawn(async move |_this, _cx| {
207 llm_token.refresh(&client).await?;
208 anyhow::Ok(())
209 })
210 .detach_and_log_err(cx);
211 },
212 ),
213 update_required: false,
214 debug_tx: None,
215 }
216 }
217
218 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
219 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
220 self.debug_tx = Some(debug_watch_tx);
221 debug_watch_rx
222 }
223
224 pub fn options(&self) -> &ZetaOptions {
225 &self.options
226 }
227
228 pub fn set_options(&mut self, options: ZetaOptions) {
229 self.options = options;
230 }
231
232 pub fn clear_history(&mut self) {
233 for zeta_project in self.projects.values_mut() {
234 zeta_project.events.clear();
235 }
236 }
237
238 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
239 self.user_store.read(cx).edit_prediction_usage()
240 }
241
242 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
243 self.get_or_init_zeta_project(project, cx);
244 }
245
246 pub fn register_buffer(
247 &mut self,
248 buffer: &Entity<Buffer>,
249 project: &Entity<Project>,
250 cx: &mut Context<Self>,
251 ) {
252 let zeta_project = self.get_or_init_zeta_project(project, cx);
253 Self::register_buffer_impl(zeta_project, buffer, project, cx);
254 }
255
256 fn get_or_init_zeta_project(
257 &mut self,
258 project: &Entity<Project>,
259 cx: &mut App,
260 ) -> &mut ZetaProject {
261 self.projects
262 .entry(project.entity_id())
263 .or_insert_with(|| ZetaProject {
264 syntax_index: cx.new(|cx| {
265 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
266 }),
267 events: VecDeque::new(),
268 registered_buffers: HashMap::new(),
269 current_prediction: None,
270 })
271 }
272
273 fn register_buffer_impl<'a>(
274 zeta_project: &'a mut ZetaProject,
275 buffer: &Entity<Buffer>,
276 project: &Entity<Project>,
277 cx: &mut Context<Self>,
278 ) -> &'a mut RegisteredBuffer {
279 let buffer_id = buffer.entity_id();
280 match zeta_project.registered_buffers.entry(buffer_id) {
281 hash_map::Entry::Occupied(entry) => entry.into_mut(),
282 hash_map::Entry::Vacant(entry) => {
283 let snapshot = buffer.read(cx).snapshot();
284 let project_entity_id = project.entity_id();
285 entry.insert(RegisteredBuffer {
286 snapshot,
287 _subscriptions: [
288 cx.subscribe(buffer, {
289 let project = project.downgrade();
290 move |this, buffer, event, cx| {
291 if let language::BufferEvent::Edited = event
292 && let Some(project) = project.upgrade()
293 {
294 this.report_changes_for_buffer(&buffer, &project, cx);
295 }
296 }
297 }),
298 cx.observe_release(buffer, move |this, _buffer, _cx| {
299 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
300 else {
301 return;
302 };
303 zeta_project.registered_buffers.remove(&buffer_id);
304 }),
305 ],
306 })
307 }
308 }
309 }
310
311 fn report_changes_for_buffer(
312 &mut self,
313 buffer: &Entity<Buffer>,
314 project: &Entity<Project>,
315 cx: &mut Context<Self>,
316 ) -> BufferSnapshot {
317 let zeta_project = self.get_or_init_zeta_project(project, cx);
318 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
319
320 let new_snapshot = buffer.read(cx).snapshot();
321 if new_snapshot.version != registered_buffer.snapshot.version {
322 let old_snapshot =
323 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
324 Self::push_event(
325 zeta_project,
326 Event::BufferChange {
327 old_snapshot,
328 new_snapshot: new_snapshot.clone(),
329 timestamp: Instant::now(),
330 },
331 );
332 }
333
334 new_snapshot
335 }
336
337 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
338 let events = &mut zeta_project.events;
339
340 if let Some(Event::BufferChange {
341 new_snapshot: last_new_snapshot,
342 timestamp: last_timestamp,
343 ..
344 }) = events.back_mut()
345 {
346 // Coalesce edits for the same buffer when they happen one after the other.
347 let Event::BufferChange {
348 old_snapshot,
349 new_snapshot,
350 timestamp,
351 } = &event;
352
353 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
354 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
355 && old_snapshot.version == last_new_snapshot.version
356 {
357 *last_new_snapshot = new_snapshot.clone();
358 *last_timestamp = *timestamp;
359 return;
360 }
361 }
362
363 if events.len() >= MAX_EVENT_COUNT {
364 // These are halved instead of popping to improve prompt caching.
365 events.drain(..MAX_EVENT_COUNT / 2);
366 }
367
368 events.push_back(event);
369 }
370
371 fn current_prediction_for_buffer(
372 &self,
373 buffer: &Entity<Buffer>,
374 project: &Entity<Project>,
375 cx: &App,
376 ) -> Option<BufferEditPrediction<'_>> {
377 let project_state = self.projects.get(&project.entity_id())?;
378
379 let CurrentEditPrediction {
380 requested_by_buffer_id,
381 prediction,
382 } = project_state.current_prediction.as_ref()?;
383
384 if prediction.targets_buffer(buffer.read(cx), cx) {
385 Some(BufferEditPrediction::Local { prediction })
386 } else if *requested_by_buffer_id == buffer.entity_id() {
387 Some(BufferEditPrediction::Jump { prediction })
388 } else {
389 None
390 }
391 }
392
393 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
394 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
395 project_state.current_prediction.take();
396 };
397 // TODO report accepted
398 }
399
400 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
401 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
402 project_state.current_prediction.take();
403 };
404 }
405
406 pub fn refresh_prediction(
407 &mut self,
408 project: &Entity<Project>,
409 buffer: &Entity<Buffer>,
410 position: language::Anchor,
411 cx: &mut Context<Self>,
412 ) -> Task<Result<()>> {
413 let request_task = self.request_prediction(project, buffer, position, cx);
414 let buffer = buffer.clone();
415 let project = project.clone();
416
417 cx.spawn(async move |this, cx| {
418 if let Some(prediction) = request_task.await? {
419 this.update(cx, |this, cx| {
420 let project_state = this
421 .projects
422 .get_mut(&project.entity_id())
423 .context("Project not found")?;
424
425 let new_prediction = CurrentEditPrediction {
426 requested_by_buffer_id: buffer.entity_id(),
427 prediction: prediction,
428 };
429
430 if project_state
431 .current_prediction
432 .as_ref()
433 .is_none_or(|old_prediction| {
434 new_prediction.should_replace_prediction(&old_prediction, cx)
435 })
436 {
437 project_state.current_prediction = Some(new_prediction);
438 }
439 anyhow::Ok(())
440 })??;
441 }
442 Ok(())
443 })
444 }
445
446 fn request_prediction(
447 &mut self,
448 project: &Entity<Project>,
449 buffer: &Entity<Buffer>,
450 position: language::Anchor,
451 cx: &mut Context<Self>,
452 ) -> Task<Result<Option<EditPrediction>>> {
453 let project_state = self.projects.get(&project.entity_id());
454
455 let index_state = project_state.map(|state| {
456 state
457 .syntax_index
458 .read_with(cx, |index, _cx| index.state().clone())
459 });
460 let options = self.options.clone();
461 let snapshot = buffer.read(cx).snapshot();
462 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
463 return Task::ready(Err(anyhow!("No file path for excerpt")));
464 };
465 let client = self.client.clone();
466 let llm_token = self.llm_token.clone();
467 let app_version = AppVersion::global(cx);
468 let worktree_snapshots = project
469 .read(cx)
470 .worktrees(cx)
471 .map(|worktree| worktree.read(cx).snapshot())
472 .collect::<Vec<_>>();
473 let debug_tx = self.debug_tx.clone();
474
475 let events = project_state
476 .map(|state| {
477 state
478 .events
479 .iter()
480 .filter_map(|event| match event {
481 Event::BufferChange {
482 old_snapshot,
483 new_snapshot,
484 ..
485 } => {
486 let path = new_snapshot.file().map(|f| f.full_path(cx));
487
488 let old_path = old_snapshot.file().and_then(|f| {
489 let old_path = f.full_path(cx);
490 if Some(&old_path) != path.as_ref() {
491 Some(old_path)
492 } else {
493 None
494 }
495 });
496
497 // TODO [zeta2] move to bg?
498 let diff =
499 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
500
501 if path == old_path && diff.is_empty() {
502 None
503 } else {
504 Some(predict_edits_v3::Event::BufferChange {
505 old_path,
506 path,
507 diff,
508 //todo: Actually detect if this edit was predicted or not
509 predicted: false,
510 })
511 }
512 }
513 })
514 .collect::<Vec<_>>()
515 })
516 .unwrap_or_default();
517
518 let diagnostics = snapshot.diagnostic_sets().clone();
519
520 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
521 let mut path = f.worktree.read(cx).absolutize(&f.path);
522 if path.pop() { Some(path) } else { None }
523 });
524
525 let request_task = cx.background_spawn({
526 let snapshot = snapshot.clone();
527 let buffer = buffer.clone();
528 async move {
529 let index_state = if let Some(index_state) = index_state {
530 Some(index_state.lock_owned().await)
531 } else {
532 None
533 };
534
535 let cursor_offset = position.to_offset(&snapshot);
536 let cursor_point = cursor_offset.to_point(&snapshot);
537
538 let before_retrieval = chrono::Utc::now();
539
540 let Some(context) = EditPredictionContext::gather_context(
541 cursor_point,
542 &snapshot,
543 parent_abs_path.as_deref(),
544 &options.context,
545 index_state.as_deref(),
546 ) else {
547 return Ok(None);
548 };
549
550 let retrieval_time = chrono::Utc::now() - before_retrieval;
551
552 let (diagnostic_groups, diagnostic_groups_truncated) =
553 Self::gather_nearby_diagnostics(
554 cursor_offset,
555 &diagnostics,
556 &snapshot,
557 options.max_diagnostic_bytes,
558 );
559
560 let debug_context = debug_tx.map(|tx| (tx, context.clone()));
561
562 let request = make_cloud_request(
563 excerpt_path,
564 context,
565 events,
566 // TODO data collection
567 false,
568 diagnostic_groups,
569 diagnostic_groups_truncated,
570 None,
571 debug_context.is_some(),
572 &worktree_snapshots,
573 index_state.as_deref(),
574 Some(options.max_prompt_bytes),
575 options.prompt_format,
576 );
577
578 let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
579 let (response_tx, response_rx) = oneshot::channel();
580
581 let local_prompt = PlannedPrompt::populate(&request)
582 .and_then(|p| p.to_prompt_string().map(|p| p.0))
583 .map_err(|err| err.to_string());
584
585 debug_tx
586 .unbounded_send(PredictionDebugInfo {
587 context,
588 retrieval_time,
589 buffer: buffer.downgrade(),
590 local_prompt,
591 position,
592 response_rx,
593 })
594 .ok();
595 Some(response_tx)
596 } else {
597 None
598 };
599
600 if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
601 if let Some(debug_response_tx) = debug_response_tx {
602 debug_response_tx
603 .send(Err("Request skipped".to_string()))
604 .ok();
605 }
606 anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
607 }
608
609 let response = Self::perform_request(client, llm_token, app_version, request).await;
610
611 if let Some(debug_response_tx) = debug_response_tx {
612 debug_response_tx
613 .send(response.as_ref().map_err(|err| err.to_string()).and_then(
614 |response| match some_or_debug_panic(response.0.debug_info.clone()) {
615 Some(debug_info) => Ok(debug_info),
616 None => Err("Missing debug info".to_string()),
617 },
618 ))
619 .ok();
620 }
621
622 anyhow::Ok(Some(response?))
623 }
624 });
625
626 let buffer = buffer.clone();
627
628 cx.spawn({
629 let project = project.clone();
630 async move |this, cx| {
631 match request_task.await {
632 Ok(Some((response, usage))) => {
633 if let Some(usage) = usage {
634 this.update(cx, |this, cx| {
635 this.user_store.update(cx, |user_store, cx| {
636 user_store.update_edit_prediction_usage(usage, cx);
637 });
638 })
639 .ok();
640 }
641
642 let prediction = EditPrediction::from_response(
643 response, &snapshot, &buffer, &project, cx,
644 )
645 .await;
646
647 // TODO telemetry: duration, etc
648 Ok(prediction)
649 }
650 Ok(None) => Ok(None),
651 Err(err) => {
652 if err.is::<ZedUpdateRequiredError>() {
653 cx.update(|cx| {
654 this.update(cx, |this, _cx| {
655 this.update_required = true;
656 })
657 .ok();
658
659 let error_message: SharedString = err.to_string().into();
660 show_app_notification(
661 NotificationId::unique::<ZedUpdateRequiredError>(),
662 cx,
663 move |cx| {
664 cx.new(|cx| {
665 ErrorMessagePrompt::new(error_message.clone(), cx)
666 .with_link_button(
667 "Update Zed",
668 "https://zed.dev/releases",
669 )
670 })
671 },
672 );
673 })
674 .ok();
675 }
676
677 Err(err)
678 }
679 }
680 }
681 })
682 }
683
684 async fn perform_request(
685 client: Arc<Client>,
686 llm_token: LlmApiToken,
687 app_version: SemanticVersion,
688 request: predict_edits_v3::PredictEditsRequest,
689 ) -> Result<(
690 predict_edits_v3::PredictEditsResponse,
691 Option<EditPredictionUsage>,
692 )> {
693 let http_client = client.http_client();
694 let mut token = llm_token.acquire(&client).await?;
695 let mut did_retry = false;
696
697 loop {
698 let request_builder = http_client::Request::builder().method(Method::POST);
699 let request_builder =
700 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
701 request_builder.uri(predict_edits_url)
702 } else {
703 request_builder.uri(
704 http_client
705 .build_zed_llm_url("/predict_edits/v3", &[])?
706 .as_ref(),
707 )
708 };
709 let request = request_builder
710 .header("Content-Type", "application/json")
711 .header("Authorization", format!("Bearer {}", token))
712 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
713 .body(serde_json::to_string(&request)?.into())?;
714
715 let mut response = http_client.send(request).await?;
716
717 if let Some(minimum_required_version) = response
718 .headers()
719 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
720 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
721 {
722 anyhow::ensure!(
723 app_version >= minimum_required_version,
724 ZedUpdateRequiredError {
725 minimum_version: minimum_required_version
726 }
727 );
728 }
729
730 if response.status().is_success() {
731 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
732
733 let mut body = Vec::new();
734 response.body_mut().read_to_end(&mut body).await?;
735 return Ok((serde_json::from_slice(&body)?, usage));
736 } else if !did_retry
737 && response
738 .headers()
739 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
740 .is_some()
741 {
742 did_retry = true;
743 token = llm_token.refresh(&client).await?;
744 } else {
745 let mut body = String::new();
746 response.body_mut().read_to_string(&mut body).await?;
747 anyhow::bail!(
748 "error predicting edits.\nStatus: {:?}\nBody: {}",
749 response.status(),
750 body
751 );
752 }
753 }
754 }
755
756 fn gather_nearby_diagnostics(
757 cursor_offset: usize,
758 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
759 snapshot: &BufferSnapshot,
760 max_diagnostics_bytes: usize,
761 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
762 // TODO: Could make this more efficient
763 let mut diagnostic_groups = Vec::new();
764 for (language_server_id, diagnostics) in diagnostic_sets {
765 let mut groups = Vec::new();
766 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
767 diagnostic_groups.extend(
768 groups
769 .into_iter()
770 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
771 );
772 }
773
774 // sort by proximity to cursor
775 diagnostic_groups.sort_by_key(|group| {
776 let range = &group.entries[group.primary_ix].range;
777 if range.start >= cursor_offset {
778 range.start - cursor_offset
779 } else if cursor_offset >= range.end {
780 cursor_offset - range.end
781 } else {
782 (cursor_offset - range.start).min(range.end - cursor_offset)
783 }
784 });
785
786 let mut results = Vec::new();
787 let mut diagnostic_groups_truncated = false;
788 let mut diagnostics_byte_count = 0;
789 for group in diagnostic_groups {
790 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
791 diagnostics_byte_count += raw_value.get().len();
792 if diagnostics_byte_count > max_diagnostics_bytes {
793 diagnostic_groups_truncated = true;
794 break;
795 }
796 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
797 }
798
799 (results, diagnostic_groups_truncated)
800 }
801
802 // TODO: Dedupe with similar code in request_prediction?
803 pub fn cloud_request_for_zeta_cli(
804 &mut self,
805 project: &Entity<Project>,
806 buffer: &Entity<Buffer>,
807 position: language::Anchor,
808 cx: &mut Context<Self>,
809 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
810 let project_state = self.projects.get(&project.entity_id());
811
812 let index_state = project_state.map(|state| {
813 state
814 .syntax_index
815 .read_with(cx, |index, _cx| index.state().clone())
816 });
817 let options = self.options.clone();
818 let snapshot = buffer.read(cx).snapshot();
819 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
820 return Task::ready(Err(anyhow!("No file path for excerpt")));
821 };
822 let worktree_snapshots = project
823 .read(cx)
824 .worktrees(cx)
825 .map(|worktree| worktree.read(cx).snapshot())
826 .collect::<Vec<_>>();
827
828 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
829 let mut path = f.worktree.read(cx).absolutize(&f.path);
830 if path.pop() { Some(path) } else { None }
831 });
832
833 cx.background_spawn(async move {
834 let index_state = if let Some(index_state) = index_state {
835 Some(index_state.lock_owned().await)
836 } else {
837 None
838 };
839
840 let cursor_point = position.to_point(&snapshot);
841
842 let debug_info = true;
843 EditPredictionContext::gather_context(
844 cursor_point,
845 &snapshot,
846 parent_abs_path.as_deref(),
847 &options.context,
848 index_state.as_deref(),
849 )
850 .context("Failed to select excerpt")
851 .map(|context| {
852 make_cloud_request(
853 excerpt_path.into(),
854 context,
855 // TODO pass everything
856 Vec::new(),
857 false,
858 Vec::new(),
859 false,
860 None,
861 debug_info,
862 &worktree_snapshots,
863 index_state.as_deref(),
864 Some(options.max_prompt_bytes),
865 options.prompt_format,
866 )
867 })
868 })
869 }
870
871 pub fn wait_for_initial_indexing(
872 &mut self,
873 project: &Entity<Project>,
874 cx: &mut App,
875 ) -> Task<Result<()>> {
876 let zeta_project = self.get_or_init_zeta_project(project, cx);
877 zeta_project
878 .syntax_index
879 .read(cx)
880 .wait_for_initial_file_indexing(cx)
881 }
882}
883
/// Error returned when the server advertises a minimum required version that
/// is newer than the running app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
891
892fn make_cloud_request(
893 excerpt_path: Arc<Path>,
894 context: EditPredictionContext,
895 events: Vec<predict_edits_v3::Event>,
896 can_collect_data: bool,
897 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
898 diagnostic_groups_truncated: bool,
899 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
900 debug_info: bool,
901 worktrees: &Vec<worktree::Snapshot>,
902 index_state: Option<&SyntaxIndexState>,
903 prompt_max_bytes: Option<usize>,
904 prompt_format: PromptFormat,
905) -> predict_edits_v3::PredictEditsRequest {
906 let mut signatures = Vec::new();
907 let mut declaration_to_signature_index = HashMap::default();
908 let mut referenced_declarations = Vec::new();
909
910 for snippet in context.declarations {
911 let project_entry_id = snippet.declaration.project_entry_id();
912 let Some(path) = worktrees.iter().find_map(|worktree| {
913 worktree.entry_for_id(project_entry_id).map(|entry| {
914 let mut full_path = RelPathBuf::new();
915 full_path.push(worktree.root_name());
916 full_path.push(&entry.path);
917 full_path
918 })
919 }) else {
920 continue;
921 };
922
923 let parent_index = index_state.and_then(|index_state| {
924 snippet.declaration.parent().and_then(|parent| {
925 add_signature(
926 parent,
927 &mut declaration_to_signature_index,
928 &mut signatures,
929 index_state,
930 )
931 })
932 });
933
934 let (text, text_is_truncated) = snippet.declaration.item_text();
935 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
936 path: path.as_std_path().into(),
937 text: text.into(),
938 range: snippet.declaration.item_line_range(),
939 text_is_truncated,
940 signature_range: snippet.declaration.signature_range_in_item_text(),
941 parent_index,
942 signature_score: snippet.score(DeclarationStyle::Signature),
943 declaration_score: snippet.score(DeclarationStyle::Declaration),
944 score_components: snippet.components,
945 });
946 }
947
948 let excerpt_parent = index_state.and_then(|index_state| {
949 context
950 .excerpt
951 .parent_declarations
952 .last()
953 .and_then(|(parent, _)| {
954 add_signature(
955 *parent,
956 &mut declaration_to_signature_index,
957 &mut signatures,
958 index_state,
959 )
960 })
961 });
962
963 predict_edits_v3::PredictEditsRequest {
964 excerpt_path,
965 excerpt: context.excerpt_text.body,
966 excerpt_line_range: context.excerpt.line_range,
967 excerpt_range: context.excerpt.range,
968 cursor_point: predict_edits_v3::Point {
969 line: predict_edits_v3::Line(context.cursor_point.row),
970 column: context.cursor_point.column,
971 },
972 referenced_declarations,
973 signatures,
974 excerpt_parent,
975 events,
976 can_collect_data,
977 diagnostic_groups,
978 diagnostic_groups_truncated,
979 git_info,
980 debug_info,
981 prompt_max_bytes,
982 prompt_format,
983 }
984}
985
986fn add_signature(
987 declaration_id: DeclarationId,
988 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
989 signatures: &mut Vec<Signature>,
990 index: &SyntaxIndexState,
991) -> Option<usize> {
992 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
993 return Some(*signature_index);
994 }
995 let Some(parent_declaration) = index.declaration(declaration_id) else {
996 log::error!("bug: missing parent declaration");
997 return None;
998 };
999 let parent_index = parent_declaration.parent().and_then(|parent| {
1000 add_signature(parent, declaration_to_signature_index, signatures, index)
1001 });
1002 let (text, text_is_truncated) = parent_declaration.signature_text();
1003 let signature_index = signatures.len();
1004 signatures.push(Signature {
1005 text: text.into(),
1006 text_is_truncated,
1007 parent_index,
1008 range: parent_declaration.signature_line_range(),
1009 });
1010 declaration_to_signature_index.insert(declaration_id, signature_index);
1011 Some(signature_index)
1012}
1013
1014#[cfg(test)]
1015mod tests {
1016 use std::{
1017 path::{Path, PathBuf},
1018 sync::Arc,
1019 };
1020
1021 use client::UserStore;
1022 use clock::FakeSystemClock;
1023 use cloud_llm_client::predict_edits_v3::{self, Point};
1024 use edit_prediction_context::Line;
1025 use futures::{
1026 AsyncReadExt, StreamExt,
1027 channel::{mpsc, oneshot},
1028 };
1029 use gpui::{
1030 Entity, TestAppContext,
1031 http_client::{FakeHttpClient, Response},
1032 prelude::*,
1033 };
1034 use indoc::indoc;
1035 use language::{LanguageServerId, OffsetRangeExt as _};
1036 use pretty_assertions::{assert_eq, assert_matches};
1037 use project::{FakeFs, Project};
1038 use serde_json::json;
1039 use settings::SettingsStore;
1040 use util::path;
1041 use uuid::Uuid;
1042
1043 use crate::{BufferEditPrediction, Zeta};
1044
    /// Checks how the current prediction is classified from a buffer's point of
    /// view: `Local` when it edits that buffer, `Jump` when it was requested
    /// from that buffer but edits another file.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at the end of "How" on the second line.
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        // The prediction edits buffer1 itself, so it is seen as `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    // Reuses snapshot1's line count; both fixture files have three lines.
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        // Requested from buffer1 but targeting 2.txt: seen from buffer1 it is a `Jump`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Seen from the targeted buffer, the same prediction is `Local`.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1141
1142 #[gpui::test]
1143 async fn test_simple_request(cx: &mut TestAppContext) {
1144 let (zeta, mut req_rx) = init_test(cx);
1145 let fs = FakeFs::new(cx.executor());
1146 fs.insert_tree(
1147 "/root",
1148 json!({
1149 "foo.md": "Hello!\nHow\nBye"
1150 }),
1151 )
1152 .await;
1153 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1154
1155 let buffer = project
1156 .update(cx, |project, cx| {
1157 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1158 project.open_buffer(path, cx)
1159 })
1160 .await
1161 .unwrap();
1162 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1163 let position = snapshot.anchor_before(language::Point::new(1, 3));
1164
1165 let prediction_task = zeta.update(cx, |zeta, cx| {
1166 zeta.request_prediction(&project, &buffer, position, cx)
1167 });
1168
1169 let (request, respond_tx) = req_rx.next().await.unwrap();
1170 assert_eq!(
1171 request.excerpt_path.as_ref(),
1172 Path::new(path!("root/foo.md"))
1173 );
1174 assert_eq!(
1175 request.cursor_point,
1176 Point {
1177 line: Line(1),
1178 column: 3
1179 }
1180 );
1181
1182 respond_tx
1183 .send(predict_edits_v3::PredictEditsResponse {
1184 request_id: Uuid::new_v4(),
1185 edits: vec![predict_edits_v3::Edit {
1186 path: Path::new(path!("root/foo.md")).into(),
1187 range: Line(0)..Line(snapshot.max_point().row + 1),
1188 content: "Hello!\nHow are you?\nBye".into(),
1189 }],
1190 debug_info: None,
1191 })
1192 .unwrap();
1193
1194 let prediction = prediction_task.await.unwrap().unwrap();
1195
1196 assert_eq!(prediction.edits.len(), 1);
1197 assert_eq!(
1198 prediction.edits[0].0.to_point(&snapshot).start,
1199 language::Point::new(1, 3)
1200 );
1201 assert_eq!(prediction.edits[0].1, " are you?");
1202 }
1203
1204 #[gpui::test]
1205 async fn test_request_events(cx: &mut TestAppContext) {
1206 let (zeta, mut req_rx) = init_test(cx);
1207 let fs = FakeFs::new(cx.executor());
1208 fs.insert_tree(
1209 "/root",
1210 json!({
1211 "foo.md": "Hello!\n\nBye"
1212 }),
1213 )
1214 .await;
1215 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1216
1217 let buffer = project
1218 .update(cx, |project, cx| {
1219 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1220 project.open_buffer(path, cx)
1221 })
1222 .await
1223 .unwrap();
1224
1225 zeta.update(cx, |zeta, cx| {
1226 zeta.register_buffer(&buffer, &project, cx);
1227 });
1228
1229 buffer.update(cx, |buffer, cx| {
1230 buffer.edit(vec![(7..7, "How")], None, cx);
1231 });
1232
1233 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1234 let position = snapshot.anchor_before(language::Point::new(1, 3));
1235
1236 let prediction_task = zeta.update(cx, |zeta, cx| {
1237 zeta.request_prediction(&project, &buffer, position, cx)
1238 });
1239
1240 let (request, respond_tx) = req_rx.next().await.unwrap();
1241
1242 assert_eq!(request.events.len(), 1);
1243 assert_eq!(
1244 request.events[0],
1245 predict_edits_v3::Event::BufferChange {
1246 path: Some(PathBuf::from(path!("root/foo.md"))),
1247 old_path: None,
1248 diff: indoc! {"
1249 @@ -1,3 +1,3 @@
1250 Hello!
1251 -
1252 +How
1253 Bye
1254 "}
1255 .to_string(),
1256 predicted: false
1257 }
1258 );
1259
1260 respond_tx
1261 .send(predict_edits_v3::PredictEditsResponse {
1262 request_id: Uuid::new_v4(),
1263 edits: vec![predict_edits_v3::Edit {
1264 path: Path::new(path!("root/foo.md")).into(),
1265 range: Line(0)..Line(snapshot.max_point().row + 1),
1266 content: "Hello!\nHow are you?\nBye".into(),
1267 }],
1268 debug_info: None,
1269 })
1270 .unwrap();
1271
1272 let prediction = prediction_task.await.unwrap().unwrap();
1273
1274 assert_eq!(prediction.edits.len(), 1);
1275 assert_eq!(
1276 prediction.edits[0].0.to_point(&snapshot).start,
1277 language::Point::new(1, 3)
1278 );
1279 assert_eq!(prediction.edits[0].1, " are you?");
1280 }
1281
1282 #[gpui::test]
1283 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1284 let (zeta, mut req_rx) = init_test(cx);
1285 let fs = FakeFs::new(cx.executor());
1286 fs.insert_tree(
1287 "/root",
1288 json!({
1289 "foo.md": "Hello!\nBye"
1290 }),
1291 )
1292 .await;
1293 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1294
1295 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1296 let diagnostic = lsp::Diagnostic {
1297 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1298 severity: Some(lsp::DiagnosticSeverity::ERROR),
1299 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1300 ..Default::default()
1301 };
1302
1303 project.update(cx, |project, cx| {
1304 project.lsp_store().update(cx, |lsp_store, cx| {
1305 // Create some diagnostics
1306 lsp_store
1307 .update_diagnostics(
1308 LanguageServerId(0),
1309 lsp::PublishDiagnosticsParams {
1310 uri: path_to_buffer_uri.clone(),
1311 diagnostics: vec![diagnostic],
1312 version: None,
1313 },
1314 None,
1315 language::DiagnosticSourceKind::Pushed,
1316 &[],
1317 cx,
1318 )
1319 .unwrap();
1320 });
1321 });
1322
1323 let buffer = project
1324 .update(cx, |project, cx| {
1325 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1326 project.open_buffer(path, cx)
1327 })
1328 .await
1329 .unwrap();
1330
1331 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1332 let position = snapshot.anchor_before(language::Point::new(0, 0));
1333
1334 let _prediction_task = zeta.update(cx, |zeta, cx| {
1335 zeta.request_prediction(&project, &buffer, position, cx)
1336 });
1337
1338 let (request, _respond_tx) = req_rx.next().await.unwrap();
1339
1340 assert_eq!(request.diagnostic_groups.len(), 1);
1341 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1342 .unwrap();
1343 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1344 assert_eq!(
1345 value,
1346 json!({
1347 "entries": [{
1348 "range": {
1349 "start": 8,
1350 "end": 10
1351 },
1352 "diagnostic": {
1353 "source": null,
1354 "code": null,
1355 "code_description": null,
1356 "severity": 1,
1357 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1358 "markdown": null,
1359 "group_id": 0,
1360 "is_primary": true,
1361 "is_disk_based": false,
1362 "is_unnecessary": false,
1363 "source_kind": "Pushed",
1364 "data": null,
1365 "underline": true
1366 }
1367 }],
1368 "primary_ix": 0
1369 })
1370 );
1371 }
1372
1373 fn init_test(
1374 cx: &mut TestAppContext,
1375 ) -> (
1376 Entity<Zeta>,
1377 mpsc::UnboundedReceiver<(
1378 predict_edits_v3::PredictEditsRequest,
1379 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1380 )>,
1381 ) {
1382 cx.update(move |cx| {
1383 let settings_store = SettingsStore::test(cx);
1384 cx.set_global(settings_store);
1385 language::init(cx);
1386 Project::init_settings(cx);
1387
1388 let (req_tx, req_rx) = mpsc::unbounded();
1389
1390 let http_client = FakeHttpClient::create({
1391 move |req| {
1392 let uri = req.uri().path().to_string();
1393 let mut body = req.into_body();
1394 let req_tx = req_tx.clone();
1395 async move {
1396 let resp = match uri.as_str() {
1397 "/client/llm_tokens" => serde_json::to_string(&json!({
1398 "token": "test"
1399 }))
1400 .unwrap(),
1401 "/predict_edits/v3" => {
1402 let mut buf = Vec::new();
1403 body.read_to_end(&mut buf).await.ok();
1404 let req = serde_json::from_slice(&buf).unwrap();
1405
1406 let (res_tx, res_rx) = oneshot::channel();
1407 req_tx.unbounded_send((req, res_tx)).unwrap();
1408 serde_json::to_string(&res_rx.await?).unwrap()
1409 }
1410 _ => {
1411 panic!("Unexpected path: {}", uri)
1412 }
1413 };
1414
1415 Ok(Response::builder().body(resp.into()).unwrap())
1416 }
1417 }
1418 });
1419
1420 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1421 client.cloud_client().set_credentials(1, "test".into());
1422
1423 language_model::init(client.clone(), cx);
1424
1425 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1426 let zeta = Zeta::global(&client, &user_store, cx);
1427
1428 (zeta, req_rx)
1429 })
1430 }
1431}