1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
11 EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
21use language::{BufferSnapshot, TextBufferSnapshot};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Changes to the same buffer that arrive within this window of the previous
/// recorded change are merged into a single `Event::BufferChange`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1000);

/// Upper bound on the number of events kept per project. When the history is
/// full, the oldest half is discarded in one go.
const MAX_EVENT_COUNT: usize = 16;
45
46pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
47 use_imports: true,
48 excerpt: EditPredictionExcerptOptions {
49 max_bytes: 512,
50 min_bytes: 128,
51 target_before_cursor_over_total_bytes: 0.5,
52 },
53 score: EditPredictionScoreOptions {
54 omit_excerpt_overlaps: true,
55 },
56};
57
58pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
59 context: DEFAULT_CONTEXT_OPTIONS,
60 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
61 max_diagnostic_bytes: 2048,
62 prompt_format: PromptFormat::DEFAULT,
63 file_indexing_parallelism: 1,
64};
65
66#[derive(Clone)]
67struct ZetaGlobal(Entity<Zeta>);
68
69impl Global for ZetaGlobal {}
70
71pub struct Zeta {
72 client: Arc<Client>,
73 user_store: Entity<UserStore>,
74 llm_token: LlmApiToken,
75 _llm_token_subscription: Subscription,
76 projects: HashMap<EntityId, ZetaProject>,
77 options: ZetaOptions,
78 update_required: bool,
79 debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
80}
81
82#[derive(Debug, Clone, PartialEq)]
83pub struct ZetaOptions {
84 pub context: EditPredictionContextOptions,
85 pub max_prompt_bytes: usize,
86 pub max_diagnostic_bytes: usize,
87 pub prompt_format: predict_edits_v3::PromptFormat,
88 pub file_indexing_parallelism: usize,
89}
90
91pub struct PredictionDebugInfo {
92 pub context: EditPredictionContext,
93 pub retrieval_time: TimeDelta,
94 pub request: RequestDebugInfo,
95 pub buffer: WeakEntity<Buffer>,
96 pub position: language::Anchor,
97}
98
99pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
100
101struct ZetaProject {
102 syntax_index: Entity<SyntaxIndex>,
103 events: VecDeque<Event>,
104 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
105 current_prediction: Option<CurrentEditPrediction>,
106}
107
108#[derive(Clone)]
109struct CurrentEditPrediction {
110 pub requested_by_buffer_id: EntityId,
111 pub prediction: EditPrediction,
112}
113
114impl CurrentEditPrediction {
115 fn should_replace_prediction(
116 &self,
117 old_prediction: &Self,
118 snapshot: &TextBufferSnapshot,
119 ) -> bool {
120 if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
121 return true;
122 }
123
124 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
125 return true;
126 };
127
128 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
129 return false;
130 };
131 if old_edits.len() == 1 && new_edits.len() == 1 {
132 let (old_range, old_text) = &old_edits[0];
133 let (new_range, new_text) = &new_edits[0];
134 new_range == old_range && new_text.starts_with(old_text)
135 } else {
136 true
137 }
138 }
139}
140
141/// A prediction from the perspective of a buffer.
142#[derive(Debug)]
143enum BufferEditPrediction<'a> {
144 Local { prediction: &'a EditPrediction },
145 Jump { prediction: &'a EditPrediction },
146}
147
148struct RegisteredBuffer {
149 snapshot: BufferSnapshot,
150 _subscriptions: [gpui::Subscription; 2],
151}
152
153#[derive(Clone)]
154pub enum Event {
155 BufferChange {
156 old_snapshot: BufferSnapshot,
157 new_snapshot: BufferSnapshot,
158 timestamp: Instant,
159 },
160}
161
162impl Zeta {
163 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
164 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
165 }
166
167 pub fn global(
168 client: &Arc<Client>,
169 user_store: &Entity<UserStore>,
170 cx: &mut App,
171 ) -> Entity<Self> {
172 cx.try_global::<ZetaGlobal>()
173 .map(|global| global.0.clone())
174 .unwrap_or_else(|| {
175 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
176 cx.set_global(ZetaGlobal(zeta.clone()));
177 zeta
178 })
179 }
180
181 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
182 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
183
184 Self {
185 projects: HashMap::new(),
186 client,
187 user_store,
188 options: DEFAULT_OPTIONS,
189 llm_token: LlmApiToken::default(),
190 _llm_token_subscription: cx.subscribe(
191 &refresh_llm_token_listener,
192 |this, _listener, _event, cx| {
193 let client = this.client.clone();
194 let llm_token = this.llm_token.clone();
195 cx.spawn(async move |_this, _cx| {
196 llm_token.refresh(&client).await?;
197 anyhow::Ok(())
198 })
199 .detach_and_log_err(cx);
200 },
201 ),
202 update_required: false,
203 debug_tx: None,
204 }
205 }
206
207 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
208 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
209 self.debug_tx = Some(debug_watch_tx);
210 debug_watch_rx
211 }
212
213 pub fn options(&self) -> &ZetaOptions {
214 &self.options
215 }
216
217 pub fn set_options(&mut self, options: ZetaOptions) {
218 self.options = options;
219 }
220
221 pub fn clear_history(&mut self) {
222 for zeta_project in self.projects.values_mut() {
223 zeta_project.events.clear();
224 }
225 }
226
227 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
228 self.user_store.read(cx).edit_prediction_usage()
229 }
230
231 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
232 self.get_or_init_zeta_project(project, cx);
233 }
234
235 pub fn register_buffer(
236 &mut self,
237 buffer: &Entity<Buffer>,
238 project: &Entity<Project>,
239 cx: &mut Context<Self>,
240 ) {
241 let zeta_project = self.get_or_init_zeta_project(project, cx);
242 Self::register_buffer_impl(zeta_project, buffer, project, cx);
243 }
244
245 fn get_or_init_zeta_project(
246 &mut self,
247 project: &Entity<Project>,
248 cx: &mut App,
249 ) -> &mut ZetaProject {
250 self.projects
251 .entry(project.entity_id())
252 .or_insert_with(|| ZetaProject {
253 syntax_index: cx.new(|cx| {
254 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
255 }),
256 events: VecDeque::new(),
257 registered_buffers: HashMap::new(),
258 current_prediction: None,
259 })
260 }
261
262 fn register_buffer_impl<'a>(
263 zeta_project: &'a mut ZetaProject,
264 buffer: &Entity<Buffer>,
265 project: &Entity<Project>,
266 cx: &mut Context<Self>,
267 ) -> &'a mut RegisteredBuffer {
268 let buffer_id = buffer.entity_id();
269 match zeta_project.registered_buffers.entry(buffer_id) {
270 hash_map::Entry::Occupied(entry) => entry.into_mut(),
271 hash_map::Entry::Vacant(entry) => {
272 let snapshot = buffer.read(cx).snapshot();
273 let project_entity_id = project.entity_id();
274 entry.insert(RegisteredBuffer {
275 snapshot,
276 _subscriptions: [
277 cx.subscribe(buffer, {
278 let project = project.downgrade();
279 move |this, buffer, event, cx| {
280 if let language::BufferEvent::Edited = event
281 && let Some(project) = project.upgrade()
282 {
283 this.report_changes_for_buffer(&buffer, &project, cx);
284 }
285 }
286 }),
287 cx.observe_release(buffer, move |this, _buffer, _cx| {
288 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
289 else {
290 return;
291 };
292 zeta_project.registered_buffers.remove(&buffer_id);
293 }),
294 ],
295 })
296 }
297 }
298 }
299
300 fn report_changes_for_buffer(
301 &mut self,
302 buffer: &Entity<Buffer>,
303 project: &Entity<Project>,
304 cx: &mut Context<Self>,
305 ) -> BufferSnapshot {
306 let zeta_project = self.get_or_init_zeta_project(project, cx);
307 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
308
309 let new_snapshot = buffer.read(cx).snapshot();
310 if new_snapshot.version != registered_buffer.snapshot.version {
311 let old_snapshot =
312 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
313 Self::push_event(
314 zeta_project,
315 Event::BufferChange {
316 old_snapshot,
317 new_snapshot: new_snapshot.clone(),
318 timestamp: Instant::now(),
319 },
320 );
321 }
322
323 new_snapshot
324 }
325
326 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
327 let events = &mut zeta_project.events;
328
329 if let Some(Event::BufferChange {
330 new_snapshot: last_new_snapshot,
331 timestamp: last_timestamp,
332 ..
333 }) = events.back_mut()
334 {
335 // Coalesce edits for the same buffer when they happen one after the other.
336 let Event::BufferChange {
337 old_snapshot,
338 new_snapshot,
339 timestamp,
340 } = &event;
341
342 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
343 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
344 && old_snapshot.version == last_new_snapshot.version
345 {
346 *last_new_snapshot = new_snapshot.clone();
347 *last_timestamp = *timestamp;
348 return;
349 }
350 }
351
352 if events.len() >= MAX_EVENT_COUNT {
353 // These are halved instead of popping to improve prompt caching.
354 events.drain(..MAX_EVENT_COUNT / 2);
355 }
356
357 events.push_back(event);
358 }
359
360 fn current_prediction_for_buffer(
361 &self,
362 buffer: &Entity<Buffer>,
363 project: &Entity<Project>,
364 cx: &App,
365 ) -> Option<BufferEditPrediction<'_>> {
366 let project_state = self.projects.get(&project.entity_id())?;
367
368 let CurrentEditPrediction {
369 requested_by_buffer_id,
370 prediction,
371 } = project_state.current_prediction.as_ref()?;
372
373 if prediction.targets_buffer(buffer.read(cx), cx) {
374 Some(BufferEditPrediction::Local { prediction })
375 } else if *requested_by_buffer_id == buffer.entity_id() {
376 Some(BufferEditPrediction::Jump { prediction })
377 } else {
378 None
379 }
380 }
381
382 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
383 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
384 project_state.current_prediction.take();
385 };
386 // TODO report accepted
387 }
388
389 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
390 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
391 project_state.current_prediction.take();
392 };
393 }
394
395 pub fn refresh_prediction(
396 &mut self,
397 project: &Entity<Project>,
398 buffer: &Entity<Buffer>,
399 position: language::Anchor,
400 cx: &mut Context<Self>,
401 ) -> Task<Result<()>> {
402 let request_task = self.request_prediction(project, buffer, position, cx);
403 let buffer = buffer.clone();
404 let project = project.clone();
405
406 cx.spawn(async move |this, cx| {
407 if let Some(prediction) = request_task.await? {
408 this.update(cx, |this, cx| {
409 let project_state = this
410 .projects
411 .get_mut(&project.entity_id())
412 .context("Project not found")?;
413
414 let new_prediction = CurrentEditPrediction {
415 requested_by_buffer_id: buffer.entity_id(),
416 prediction: prediction,
417 };
418
419 if project_state
420 .current_prediction
421 .as_ref()
422 .is_none_or(|old_prediction| {
423 new_prediction
424 .should_replace_prediction(&old_prediction, buffer.read(cx))
425 })
426 {
427 project_state.current_prediction = Some(new_prediction);
428 }
429 anyhow::Ok(())
430 })??;
431 }
432 Ok(())
433 })
434 }
435
436 fn request_prediction(
437 &mut self,
438 project: &Entity<Project>,
439 buffer: &Entity<Buffer>,
440 position: language::Anchor,
441 cx: &mut Context<Self>,
442 ) -> Task<Result<Option<EditPrediction>>> {
443 let project_state = self.projects.get(&project.entity_id());
444
445 let index_state = project_state.map(|state| {
446 state
447 .syntax_index
448 .read_with(cx, |index, _cx| index.state().clone())
449 });
450 let options = self.options.clone();
451 let snapshot = buffer.read(cx).snapshot();
452 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
453 return Task::ready(Err(anyhow!("No file path for excerpt")));
454 };
455 let client = self.client.clone();
456 let llm_token = self.llm_token.clone();
457 let app_version = AppVersion::global(cx);
458 let worktree_snapshots = project
459 .read(cx)
460 .worktrees(cx)
461 .map(|worktree| worktree.read(cx).snapshot())
462 .collect::<Vec<_>>();
463 let debug_tx = self.debug_tx.clone();
464
465 let events = project_state
466 .map(|state| {
467 state
468 .events
469 .iter()
470 .filter_map(|event| match event {
471 Event::BufferChange {
472 old_snapshot,
473 new_snapshot,
474 ..
475 } => {
476 let path = new_snapshot.file().map(|f| f.full_path(cx));
477
478 let old_path = old_snapshot.file().and_then(|f| {
479 let old_path = f.full_path(cx);
480 if Some(&old_path) != path.as_ref() {
481 Some(old_path)
482 } else {
483 None
484 }
485 });
486
487 // TODO [zeta2] move to bg?
488 let diff =
489 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
490
491 if path == old_path && diff.is_empty() {
492 None
493 } else {
494 Some(predict_edits_v3::Event::BufferChange {
495 old_path,
496 path,
497 diff,
498 //todo: Actually detect if this edit was predicted or not
499 predicted: false,
500 })
501 }
502 }
503 })
504 .collect::<Vec<_>>()
505 })
506 .unwrap_or_default();
507
508 let diagnostics = snapshot.diagnostic_sets().clone();
509
510 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
511 let mut path = f.worktree.read(cx).absolutize(&f.path);
512 if path.pop() { Some(path) } else { None }
513 });
514
515 let request_task = cx.background_spawn({
516 let snapshot = snapshot.clone();
517 let buffer = buffer.clone();
518 async move {
519 let index_state = if let Some(index_state) = index_state {
520 Some(index_state.lock_owned().await)
521 } else {
522 None
523 };
524
525 let cursor_offset = position.to_offset(&snapshot);
526 let cursor_point = cursor_offset.to_point(&snapshot);
527
528 let before_retrieval = chrono::Utc::now();
529
530 let Some(context) = EditPredictionContext::gather_context(
531 cursor_point,
532 &snapshot,
533 parent_abs_path.as_deref(),
534 &options.context,
535 index_state.as_deref(),
536 ) else {
537 return Ok(None);
538 };
539
540 let debug_context = if let Some(debug_tx) = debug_tx {
541 Some((debug_tx, context.clone()))
542 } else {
543 None
544 };
545
546 let (diagnostic_groups, diagnostic_groups_truncated) =
547 Self::gather_nearby_diagnostics(
548 cursor_offset,
549 &diagnostics,
550 &snapshot,
551 options.max_diagnostic_bytes,
552 );
553
554 let request = make_cloud_request(
555 excerpt_path,
556 context,
557 events,
558 // TODO data collection
559 false,
560 diagnostic_groups,
561 diagnostic_groups_truncated,
562 None,
563 debug_context.is_some(),
564 &worktree_snapshots,
565 index_state.as_deref(),
566 Some(options.max_prompt_bytes),
567 options.prompt_format,
568 );
569
570 let retrieval_time = chrono::Utc::now() - before_retrieval;
571 let response = Self::perform_request(client, llm_token, app_version, request).await;
572
573 if let Some((debug_tx, context)) = debug_context {
574 debug_tx
575 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
576 |response| {
577 let Some(request) =
578 some_or_debug_panic(response.0.debug_info.clone())
579 else {
580 return Err("Missing debug info".to_string());
581 };
582 Ok(PredictionDebugInfo {
583 context,
584 request,
585 retrieval_time,
586 buffer: buffer.downgrade(),
587 position,
588 })
589 },
590 ))
591 .ok();
592 }
593
594 anyhow::Ok(Some(response?))
595 }
596 });
597
598 let buffer = buffer.clone();
599
600 cx.spawn({
601 let project = project.clone();
602 async move |this, cx| {
603 match request_task.await {
604 Ok(Some((response, usage))) => {
605 if let Some(usage) = usage {
606 this.update(cx, |this, cx| {
607 this.user_store.update(cx, |user_store, cx| {
608 user_store.update_edit_prediction_usage(usage, cx);
609 });
610 })
611 .ok();
612 }
613
614 let prediction = EditPrediction::from_response(
615 response, &snapshot, &buffer, &project, cx,
616 )
617 .await;
618
619 // TODO telemetry: duration, etc
620 Ok(prediction)
621 }
622 Ok(None) => Ok(None),
623 Err(err) => {
624 if err.is::<ZedUpdateRequiredError>() {
625 cx.update(|cx| {
626 this.update(cx, |this, _cx| {
627 this.update_required = true;
628 })
629 .ok();
630
631 let error_message: SharedString = err.to_string().into();
632 show_app_notification(
633 NotificationId::unique::<ZedUpdateRequiredError>(),
634 cx,
635 move |cx| {
636 cx.new(|cx| {
637 ErrorMessagePrompt::new(error_message.clone(), cx)
638 .with_link_button(
639 "Update Zed",
640 "https://zed.dev/releases",
641 )
642 })
643 },
644 );
645 })
646 .ok();
647 }
648
649 Err(err)
650 }
651 }
652 }
653 })
654 }
655
656 async fn perform_request(
657 client: Arc<Client>,
658 llm_token: LlmApiToken,
659 app_version: SemanticVersion,
660 request: predict_edits_v3::PredictEditsRequest,
661 ) -> Result<(
662 predict_edits_v3::PredictEditsResponse,
663 Option<EditPredictionUsage>,
664 )> {
665 let http_client = client.http_client();
666 let mut token = llm_token.acquire(&client).await?;
667 let mut did_retry = false;
668
669 loop {
670 let request_builder = http_client::Request::builder().method(Method::POST);
671 let request_builder =
672 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
673 request_builder.uri(predict_edits_url)
674 } else {
675 request_builder.uri(
676 http_client
677 .build_zed_llm_url("/predict_edits/v3", &[])?
678 .as_ref(),
679 )
680 };
681 let request = request_builder
682 .header("Content-Type", "application/json")
683 .header("Authorization", format!("Bearer {}", token))
684 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
685 .body(serde_json::to_string(&request)?.into())?;
686
687 let mut response = http_client.send(request).await?;
688
689 if let Some(minimum_required_version) = response
690 .headers()
691 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
692 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
693 {
694 anyhow::ensure!(
695 app_version >= minimum_required_version,
696 ZedUpdateRequiredError {
697 minimum_version: minimum_required_version
698 }
699 );
700 }
701
702 if response.status().is_success() {
703 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
704
705 let mut body = Vec::new();
706 response.body_mut().read_to_end(&mut body).await?;
707 return Ok((serde_json::from_slice(&body)?, usage));
708 } else if !did_retry
709 && response
710 .headers()
711 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
712 .is_some()
713 {
714 did_retry = true;
715 token = llm_token.refresh(&client).await?;
716 } else {
717 let mut body = String::new();
718 response.body_mut().read_to_string(&mut body).await?;
719 anyhow::bail!(
720 "error predicting edits.\nStatus: {:?}\nBody: {}",
721 response.status(),
722 body
723 );
724 }
725 }
726 }
727
728 fn gather_nearby_diagnostics(
729 cursor_offset: usize,
730 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
731 snapshot: &BufferSnapshot,
732 max_diagnostics_bytes: usize,
733 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
734 // TODO: Could make this more efficient
735 let mut diagnostic_groups = Vec::new();
736 for (language_server_id, diagnostics) in diagnostic_sets {
737 let mut groups = Vec::new();
738 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
739 diagnostic_groups.extend(
740 groups
741 .into_iter()
742 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
743 );
744 }
745
746 // sort by proximity to cursor
747 diagnostic_groups.sort_by_key(|group| {
748 let range = &group.entries[group.primary_ix].range;
749 if range.start >= cursor_offset {
750 range.start - cursor_offset
751 } else if cursor_offset >= range.end {
752 cursor_offset - range.end
753 } else {
754 (cursor_offset - range.start).min(range.end - cursor_offset)
755 }
756 });
757
758 let mut results = Vec::new();
759 let mut diagnostic_groups_truncated = false;
760 let mut diagnostics_byte_count = 0;
761 for group in diagnostic_groups {
762 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
763 diagnostics_byte_count += raw_value.get().len();
764 if diagnostics_byte_count > max_diagnostics_bytes {
765 diagnostic_groups_truncated = true;
766 break;
767 }
768 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
769 }
770
771 (results, diagnostic_groups_truncated)
772 }
773
774 // TODO: Dedupe with similar code in request_prediction?
775 pub fn cloud_request_for_zeta_cli(
776 &mut self,
777 project: &Entity<Project>,
778 buffer: &Entity<Buffer>,
779 position: language::Anchor,
780 cx: &mut Context<Self>,
781 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
782 let project_state = self.projects.get(&project.entity_id());
783
784 let index_state = project_state.map(|state| {
785 state
786 .syntax_index
787 .read_with(cx, |index, _cx| index.state().clone())
788 });
789 let options = self.options.clone();
790 let snapshot = buffer.read(cx).snapshot();
791 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
792 return Task::ready(Err(anyhow!("No file path for excerpt")));
793 };
794 let worktree_snapshots = project
795 .read(cx)
796 .worktrees(cx)
797 .map(|worktree| worktree.read(cx).snapshot())
798 .collect::<Vec<_>>();
799
800 let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
801 let mut path = f.worktree.read(cx).absolutize(&f.path);
802 if path.pop() { Some(path) } else { None }
803 });
804
805 cx.background_spawn(async move {
806 let index_state = if let Some(index_state) = index_state {
807 Some(index_state.lock_owned().await)
808 } else {
809 None
810 };
811
812 let cursor_point = position.to_point(&snapshot);
813
814 let debug_info = true;
815 EditPredictionContext::gather_context(
816 cursor_point,
817 &snapshot,
818 parent_abs_path.as_deref(),
819 &options.context,
820 index_state.as_deref(),
821 )
822 .context("Failed to select excerpt")
823 .map(|context| {
824 make_cloud_request(
825 excerpt_path.into(),
826 context,
827 // TODO pass everything
828 Vec::new(),
829 false,
830 Vec::new(),
831 false,
832 None,
833 debug_info,
834 &worktree_snapshots,
835 index_state.as_deref(),
836 Some(options.max_prompt_bytes),
837 options.prompt_format,
838 )
839 })
840 })
841 }
842
843 pub fn wait_for_initial_indexing(
844 &mut self,
845 project: &Entity<Project>,
846 cx: &mut App,
847 ) -> Task<Result<()>> {
848 let zeta_project = self.get_or_init_zeta_project(project, cx);
849 zeta_project
850 .syntax_index
851 .read(cx)
852 .wait_for_initial_file_indexing(cx)
853 }
854}
855
856#[derive(Error, Debug)]
857#[error(
858 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
859)]
860pub struct ZedUpdateRequiredError {
861 minimum_version: SemanticVersion,
862}
863
864fn make_cloud_request(
865 excerpt_path: Arc<Path>,
866 context: EditPredictionContext,
867 events: Vec<predict_edits_v3::Event>,
868 can_collect_data: bool,
869 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
870 diagnostic_groups_truncated: bool,
871 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
872 debug_info: bool,
873 worktrees: &Vec<worktree::Snapshot>,
874 index_state: Option<&SyntaxIndexState>,
875 prompt_max_bytes: Option<usize>,
876 prompt_format: PromptFormat,
877) -> predict_edits_v3::PredictEditsRequest {
878 let mut signatures = Vec::new();
879 let mut declaration_to_signature_index = HashMap::default();
880 let mut referenced_declarations = Vec::new();
881
882 for snippet in context.declarations {
883 let project_entry_id = snippet.declaration.project_entry_id();
884 let Some(path) = worktrees.iter().find_map(|worktree| {
885 worktree.entry_for_id(project_entry_id).map(|entry| {
886 let mut full_path = RelPathBuf::new();
887 full_path.push(worktree.root_name());
888 full_path.push(&entry.path);
889 full_path
890 })
891 }) else {
892 continue;
893 };
894
895 let parent_index = index_state.and_then(|index_state| {
896 snippet.declaration.parent().and_then(|parent| {
897 add_signature(
898 parent,
899 &mut declaration_to_signature_index,
900 &mut signatures,
901 index_state,
902 )
903 })
904 });
905
906 let (text, text_is_truncated) = snippet.declaration.item_text();
907 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
908 path: path.as_std_path().into(),
909 text: text.into(),
910 range: snippet.declaration.item_range(),
911 text_is_truncated,
912 signature_range: snippet.declaration.signature_range_in_item_text(),
913 parent_index,
914 signature_score: snippet.score(DeclarationStyle::Signature),
915 declaration_score: snippet.score(DeclarationStyle::Declaration),
916 score_components: snippet.components,
917 });
918 }
919
920 let excerpt_parent = index_state.and_then(|index_state| {
921 context
922 .excerpt
923 .parent_declarations
924 .last()
925 .and_then(|(parent, _)| {
926 add_signature(
927 *parent,
928 &mut declaration_to_signature_index,
929 &mut signatures,
930 index_state,
931 )
932 })
933 });
934
935 predict_edits_v3::PredictEditsRequest {
936 excerpt_path,
937 excerpt: context.excerpt_text.body,
938 excerpt_range: context.excerpt.range,
939 cursor_offset: context.cursor_offset_in_excerpt,
940 referenced_declarations,
941 signatures,
942 excerpt_parent,
943 events,
944 can_collect_data,
945 diagnostic_groups,
946 diagnostic_groups_truncated,
947 git_info,
948 debug_info,
949 prompt_max_bytes,
950 prompt_format,
951 }
952}
953
954fn add_signature(
955 declaration_id: DeclarationId,
956 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
957 signatures: &mut Vec<Signature>,
958 index: &SyntaxIndexState,
959) -> Option<usize> {
960 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
961 return Some(*signature_index);
962 }
963 let Some(parent_declaration) = index.declaration(declaration_id) else {
964 log::error!("bug: missing parent declaration");
965 return None;
966 };
967 let parent_index = parent_declaration.parent().and_then(|parent| {
968 add_signature(parent, declaration_to_signature_index, signatures, index)
969 });
970 let (text, text_is_truncated) = parent_declaration.signature_text();
971 let signature_index = signatures.len();
972 signatures.push(Signature {
973 text: text.into(),
974 text_is_truncated,
975 parent_index,
976 range: parent_declaration.signature_range(),
977 });
978 declaration_to_signature_index.insert(declaration_id, signature_index);
979 Some(signature_index)
980}
981
982#[cfg(test)]
983mod tests {
984 use std::{
985 path::{Path, PathBuf},
986 sync::Arc,
987 };
988
989 use client::UserStore;
990 use clock::FakeSystemClock;
991 use cloud_llm_client::predict_edits_v3;
992 use futures::{
993 AsyncReadExt, StreamExt,
994 channel::{mpsc, oneshot},
995 };
996 use gpui::{
997 Entity, TestAppContext,
998 http_client::{FakeHttpClient, Response},
999 prelude::*,
1000 };
1001 use indoc::indoc;
1002 use language::{LanguageServerId, OffsetRangeExt as _};
1003 use pretty_assertions::{assert_eq, assert_matches};
1004 use project::{FakeFs, Project};
1005 use serde_json::json;
1006 use settings::SettingsStore;
1007 use util::path;
1008 use uuid::Uuid;
1009
1010 use crate::{BufferEditPrediction, Zeta};
1011
1012 #[gpui::test]
1013 async fn test_current_state(cx: &mut TestAppContext) {
1014 let (zeta, mut req_rx) = init_test(cx);
1015 let fs = FakeFs::new(cx.executor());
1016 fs.insert_tree(
1017 "/root",
1018 json!({
1019 "1.txt": "Hello!\nHow\nBye",
1020 "2.txt": "Hola!\nComo\nAdios"
1021 }),
1022 )
1023 .await;
1024 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1025
1026 zeta.update(cx, |zeta, cx| {
1027 zeta.register_project(&project, cx);
1028 });
1029
1030 let buffer1 = project
1031 .update(cx, |project, cx| {
1032 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1033 project.open_buffer(path, cx)
1034 })
1035 .await
1036 .unwrap();
1037 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1038 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1039
1040 // Prediction for current file
1041
1042 let prediction_task = zeta.update(cx, |zeta, cx| {
1043 zeta.refresh_prediction(&project, &buffer1, position, cx)
1044 });
1045 let (_request, respond_tx) = req_rx.next().await.unwrap();
1046 respond_tx
1047 .send(predict_edits_v3::PredictEditsResponse {
1048 request_id: Uuid::new_v4(),
1049 edits: vec![predict_edits_v3::Edit {
1050 path: Path::new(path!("root/1.txt")).into(),
1051 range: 0..snapshot1.len(),
1052 content: "Hello!\nHow are you?\nBye".into(),
1053 }],
1054 debug_info: None,
1055 })
1056 .unwrap();
1057 prediction_task.await.unwrap();
1058
1059 zeta.read_with(cx, |zeta, cx| {
1060 let prediction = zeta
1061 .current_prediction_for_buffer(&buffer1, &project, cx)
1062 .unwrap();
1063 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1064 });
1065
1066 // Prediction for another file
1067
1068 let prediction_task = zeta.update(cx, |zeta, cx| {
1069 zeta.refresh_prediction(&project, &buffer1, position, cx)
1070 });
1071 let (_request, respond_tx) = req_rx.next().await.unwrap();
1072 respond_tx
1073 .send(predict_edits_v3::PredictEditsResponse {
1074 request_id: Uuid::new_v4(),
1075 edits: vec![predict_edits_v3::Edit {
1076 path: Path::new(path!("root/2.txt")).into(),
1077 range: 0..snapshot1.len(),
1078 content: "Hola!\nComo estas?\nAdios".into(),
1079 }],
1080 debug_info: None,
1081 })
1082 .unwrap();
1083 prediction_task.await.unwrap();
1084
1085 zeta.read_with(cx, |zeta, cx| {
1086 let prediction = zeta
1087 .current_prediction_for_buffer(&buffer1, &project, cx)
1088 .unwrap();
1089 assert_matches!(
1090 prediction,
1091 BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1092 );
1093 });
1094
1095 let buffer2 = project
1096 .update(cx, |project, cx| {
1097 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1098 project.open_buffer(path, cx)
1099 })
1100 .await
1101 .unwrap();
1102
1103 zeta.read_with(cx, |zeta, cx| {
1104 let prediction = zeta
1105 .current_prediction_for_buffer(&buffer2, &project, cx)
1106 .unwrap();
1107 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1108 });
1109 }
1110
1111 #[gpui::test]
1112 async fn test_simple_request(cx: &mut TestAppContext) {
1113 let (zeta, mut req_rx) = init_test(cx);
1114 let fs = FakeFs::new(cx.executor());
1115 fs.insert_tree(
1116 "/root",
1117 json!({
1118 "foo.md": "Hello!\nHow\nBye"
1119 }),
1120 )
1121 .await;
1122 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1123
1124 let buffer = project
1125 .update(cx, |project, cx| {
1126 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1127 project.open_buffer(path, cx)
1128 })
1129 .await
1130 .unwrap();
1131 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1132 let position = snapshot.anchor_before(language::Point::new(1, 3));
1133
1134 let prediction_task = zeta.update(cx, |zeta, cx| {
1135 zeta.request_prediction(&project, &buffer, position, cx)
1136 });
1137
1138 let (request, respond_tx) = req_rx.next().await.unwrap();
1139 assert_eq!(
1140 request.excerpt_path.as_ref(),
1141 Path::new(path!("root/foo.md"))
1142 );
1143 assert_eq!(request.cursor_offset, 10);
1144
1145 respond_tx
1146 .send(predict_edits_v3::PredictEditsResponse {
1147 request_id: Uuid::new_v4(),
1148 edits: vec![predict_edits_v3::Edit {
1149 path: Path::new(path!("root/foo.md")).into(),
1150 range: 0..snapshot.len(),
1151 content: "Hello!\nHow are you?\nBye".into(),
1152 }],
1153 debug_info: None,
1154 })
1155 .unwrap();
1156
1157 let prediction = prediction_task.await.unwrap().unwrap();
1158
1159 assert_eq!(prediction.edits.len(), 1);
1160 assert_eq!(
1161 prediction.edits[0].0.to_point(&snapshot).start,
1162 language::Point::new(1, 3)
1163 );
1164 assert_eq!(prediction.edits[0].1, " are you?");
1165 }
1166
1167 #[gpui::test]
1168 async fn test_request_events(cx: &mut TestAppContext) {
1169 let (zeta, mut req_rx) = init_test(cx);
1170 let fs = FakeFs::new(cx.executor());
1171 fs.insert_tree(
1172 "/root",
1173 json!({
1174 "foo.md": "Hello!\n\nBye"
1175 }),
1176 )
1177 .await;
1178 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1179
1180 let buffer = project
1181 .update(cx, |project, cx| {
1182 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1183 project.open_buffer(path, cx)
1184 })
1185 .await
1186 .unwrap();
1187
1188 zeta.update(cx, |zeta, cx| {
1189 zeta.register_buffer(&buffer, &project, cx);
1190 });
1191
1192 buffer.update(cx, |buffer, cx| {
1193 buffer.edit(vec![(7..7, "How")], None, cx);
1194 });
1195
1196 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1197 let position = snapshot.anchor_before(language::Point::new(1, 3));
1198
1199 let prediction_task = zeta.update(cx, |zeta, cx| {
1200 zeta.request_prediction(&project, &buffer, position, cx)
1201 });
1202
1203 let (request, respond_tx) = req_rx.next().await.unwrap();
1204
1205 assert_eq!(request.events.len(), 1);
1206 assert_eq!(
1207 request.events[0],
1208 predict_edits_v3::Event::BufferChange {
1209 path: Some(PathBuf::from(path!("root/foo.md"))),
1210 old_path: None,
1211 diff: indoc! {"
1212 @@ -1,3 +1,3 @@
1213 Hello!
1214 -
1215 +How
1216 Bye
1217 "}
1218 .to_string(),
1219 predicted: false
1220 }
1221 );
1222
1223 respond_tx
1224 .send(predict_edits_v3::PredictEditsResponse {
1225 request_id: Uuid::new_v4(),
1226 edits: vec![predict_edits_v3::Edit {
1227 path: Path::new(path!("root/foo.md")).into(),
1228 range: 0..snapshot.len(),
1229 content: "Hello!\nHow are you?\nBye".into(),
1230 }],
1231 debug_info: None,
1232 })
1233 .unwrap();
1234
1235 let prediction = prediction_task.await.unwrap().unwrap();
1236
1237 assert_eq!(prediction.edits.len(), 1);
1238 assert_eq!(
1239 prediction.edits[0].0.to_point(&snapshot).start,
1240 language::Point::new(1, 3)
1241 );
1242 assert_eq!(prediction.edits[0].1, " are you?");
1243 }
1244
1245 #[gpui::test]
1246 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1247 let (zeta, mut req_rx) = init_test(cx);
1248 let fs = FakeFs::new(cx.executor());
1249 fs.insert_tree(
1250 "/root",
1251 json!({
1252 "foo.md": "Hello!\nBye"
1253 }),
1254 )
1255 .await;
1256 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1257
1258 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1259 let diagnostic = lsp::Diagnostic {
1260 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1261 severity: Some(lsp::DiagnosticSeverity::ERROR),
1262 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1263 ..Default::default()
1264 };
1265
1266 project.update(cx, |project, cx| {
1267 project.lsp_store().update(cx, |lsp_store, cx| {
1268 // Create some diagnostics
1269 lsp_store
1270 .update_diagnostics(
1271 LanguageServerId(0),
1272 lsp::PublishDiagnosticsParams {
1273 uri: path_to_buffer_uri.clone(),
1274 diagnostics: vec![diagnostic],
1275 version: None,
1276 },
1277 None,
1278 language::DiagnosticSourceKind::Pushed,
1279 &[],
1280 cx,
1281 )
1282 .unwrap();
1283 });
1284 });
1285
1286 let buffer = project
1287 .update(cx, |project, cx| {
1288 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1289 project.open_buffer(path, cx)
1290 })
1291 .await
1292 .unwrap();
1293
1294 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1295 let position = snapshot.anchor_before(language::Point::new(0, 0));
1296
1297 let _prediction_task = zeta.update(cx, |zeta, cx| {
1298 zeta.request_prediction(&project, &buffer, position, cx)
1299 });
1300
1301 let (request, _respond_tx) = req_rx.next().await.unwrap();
1302
1303 assert_eq!(request.diagnostic_groups.len(), 1);
1304 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1305 .unwrap();
1306 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1307 assert_eq!(
1308 value,
1309 json!({
1310 "entries": [{
1311 "range": {
1312 "start": 8,
1313 "end": 10
1314 },
1315 "diagnostic": {
1316 "source": null,
1317 "code": null,
1318 "code_description": null,
1319 "severity": 1,
1320 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1321 "markdown": null,
1322 "group_id": 0,
1323 "is_primary": true,
1324 "is_disk_based": false,
1325 "is_unnecessary": false,
1326 "source_kind": "Pushed",
1327 "data": null,
1328 "underline": true
1329 }
1330 }],
1331 "primary_ix": 0
1332 })
1333 );
1334 }
1335
1336 fn init_test(
1337 cx: &mut TestAppContext,
1338 ) -> (
1339 Entity<Zeta>,
1340 mpsc::UnboundedReceiver<(
1341 predict_edits_v3::PredictEditsRequest,
1342 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1343 )>,
1344 ) {
1345 cx.update(move |cx| {
1346 let settings_store = SettingsStore::test(cx);
1347 cx.set_global(settings_store);
1348 language::init(cx);
1349 Project::init_settings(cx);
1350
1351 let (req_tx, req_rx) = mpsc::unbounded();
1352
1353 let http_client = FakeHttpClient::create({
1354 move |req| {
1355 let uri = req.uri().path().to_string();
1356 let mut body = req.into_body();
1357 let req_tx = req_tx.clone();
1358 async move {
1359 let resp = match uri.as_str() {
1360 "/client/llm_tokens" => serde_json::to_string(&json!({
1361 "token": "test"
1362 }))
1363 .unwrap(),
1364 "/predict_edits/v3" => {
1365 let mut buf = Vec::new();
1366 body.read_to_end(&mut buf).await.ok();
1367 let req = serde_json::from_slice(&buf).unwrap();
1368
1369 let (res_tx, res_rx) = oneshot::channel();
1370 req_tx.unbounded_send((req, res_tx)).unwrap();
1371 serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1372 }
1373 _ => {
1374 panic!("Unexpected path: {}", uri)
1375 }
1376 };
1377
1378 Ok(Response::builder().body(resp.into()).unwrap())
1379 }
1380 }
1381 });
1382
1383 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1384 client.cloud_client().set_credentials(1, "test".into());
1385
1386 language_model::init(client.clone(), cx);
1387
1388 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1389 let zeta = Zeta::global(&client, &user_store, cx);
1390
1391 (zeta, req_rx)
1392 })
1393 }
1394}