1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
21use language::{BufferSnapshot, TextBufferSnapshot};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Consecutive changes to the same buffer are coalesced into a single
/// [`Event::BufferChange`] when each arrives within this interval of the previous one
/// (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When the limit is reached, the oldest half of the history is dropped at once
/// (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;

/// Default excerpt-selection options used by [`DEFAULT_OPTIONS`].
// NOTE(review): `target_before_cursor_over_total_bytes` presumably is the share of the
// excerpt placed before the cursor — confirm against `EditPredictionExcerptOptions`.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Options a newly constructed `Zeta` starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
};
58
/// App-global handle to the shared [`Zeta`] entity, installed by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
63
/// Edit-prediction engine.
///
/// Tracks per-project buffer-change history and the current prediction, and sends
/// prediction requests to the Zed LLM endpoint.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate prediction requests; refreshed in the background whenever
    /// the global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Held so the token-refresh subscription stays active for the lifetime of this entity.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to subsequent prediction requests.
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is below its minimum.
    update_required: bool,
    /// When set (via `debug_info`), each request made by `request_prediction` reports its
    /// debug info, or an error string, on this channel.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
74
/// Tunable parameters for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How the excerpt around the cursor is selected.
    pub excerpt: EditPredictionExcerptOptions,
    /// Upper bound on prompt size, forwarded to the server as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
82
/// Debug information about a single prediction request, delivered through the channel
/// returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time spent gathering context and diagnostics and building the request, measured
    /// before the request is sent.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in its response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-provided debug info attached to prediction responses.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
92
/// Per-project prediction state.
struct ZetaProject {
    /// Syntax index whose state is used when gathering context for requests.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first, holding at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently held for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
99
/// The prediction currently held for a project, together with the buffer it was
/// requested from.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer from which the prediction was requested.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
105
106impl CurrentEditPrediction {
107 fn should_replace_prediction(
108 &self,
109 old_prediction: &Self,
110 snapshot: &TextBufferSnapshot,
111 ) -> bool {
112 if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
113 return true;
114 }
115
116 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
117 return true;
118 };
119
120 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
121 return false;
122 };
123 if old_edits.len() == 1 && new_edits.len() == 1 {
124 let (old_range, old_text) = &old_edits[0];
125 let (new_range, new_text) = &new_edits[0];
126 new_range == old_range && new_text.starts_with(old_text)
127 } else {
128 true
129 }
130 }
131}
132
/// A prediction from the perspective of a buffer (see `Zeta::current_prediction_for_buffer`).
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer but was requested from this one
    /// (presumably surfaced as a jump target — confirm with the provider).
    Jump { prediction: &'a EditPrediction },
}
139
/// A buffer whose edits are tracked in the project's event history.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; newer snapshots are compared against it.
    snapshot: BufferSnapshot,
    /// The `Edited` event subscription and the release observer created at registration.
    _subscriptions: [gpui::Subscription; 2],
}
144
/// An entry in a project's event history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    ///
    /// Rapid consecutive changes to the same buffer are coalesced into one event, in which
    /// case `new_snapshot` and `timestamp` reflect the latest change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
153
154impl Zeta {
155 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
156 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
157 }
158
159 pub fn global(
160 client: &Arc<Client>,
161 user_store: &Entity<UserStore>,
162 cx: &mut App,
163 ) -> Entity<Self> {
164 cx.try_global::<ZetaGlobal>()
165 .map(|global| global.0.clone())
166 .unwrap_or_else(|| {
167 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
168 cx.set_global(ZetaGlobal(zeta.clone()));
169 zeta
170 })
171 }
172
173 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
174 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
175
176 Self {
177 projects: HashMap::new(),
178 client,
179 user_store,
180 options: DEFAULT_OPTIONS,
181 llm_token: LlmApiToken::default(),
182 _llm_token_subscription: cx.subscribe(
183 &refresh_llm_token_listener,
184 |this, _listener, _event, cx| {
185 let client = this.client.clone();
186 let llm_token = this.llm_token.clone();
187 cx.spawn(async move |_this, _cx| {
188 llm_token.refresh(&client).await?;
189 anyhow::Ok(())
190 })
191 .detach_and_log_err(cx);
192 },
193 ),
194 update_required: false,
195 debug_tx: None,
196 }
197 }
198
199 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
200 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
201 self.debug_tx = Some(debug_watch_tx);
202 debug_watch_rx
203 }
204
205 pub fn options(&self) -> &ZetaOptions {
206 &self.options
207 }
208
209 pub fn set_options(&mut self, options: ZetaOptions) {
210 self.options = options;
211 }
212
213 pub fn clear_history(&mut self) {
214 for zeta_project in self.projects.values_mut() {
215 zeta_project.events.clear();
216 }
217 }
218
219 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
220 self.user_store.read(cx).edit_prediction_usage()
221 }
222
223 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
224 self.get_or_init_zeta_project(project, cx);
225 }
226
227 pub fn register_buffer(
228 &mut self,
229 buffer: &Entity<Buffer>,
230 project: &Entity<Project>,
231 cx: &mut Context<Self>,
232 ) {
233 let zeta_project = self.get_or_init_zeta_project(project, cx);
234 Self::register_buffer_impl(zeta_project, buffer, project, cx);
235 }
236
237 fn get_or_init_zeta_project(
238 &mut self,
239 project: &Entity<Project>,
240 cx: &mut App,
241 ) -> &mut ZetaProject {
242 self.projects
243 .entry(project.entity_id())
244 .or_insert_with(|| ZetaProject {
245 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
246 events: VecDeque::new(),
247 registered_buffers: HashMap::new(),
248 current_prediction: None,
249 })
250 }
251
252 fn register_buffer_impl<'a>(
253 zeta_project: &'a mut ZetaProject,
254 buffer: &Entity<Buffer>,
255 project: &Entity<Project>,
256 cx: &mut Context<Self>,
257 ) -> &'a mut RegisteredBuffer {
258 let buffer_id = buffer.entity_id();
259 match zeta_project.registered_buffers.entry(buffer_id) {
260 hash_map::Entry::Occupied(entry) => entry.into_mut(),
261 hash_map::Entry::Vacant(entry) => {
262 let snapshot = buffer.read(cx).snapshot();
263 let project_entity_id = project.entity_id();
264 entry.insert(RegisteredBuffer {
265 snapshot,
266 _subscriptions: [
267 cx.subscribe(buffer, {
268 let project = project.downgrade();
269 move |this, buffer, event, cx| {
270 if let language::BufferEvent::Edited = event
271 && let Some(project) = project.upgrade()
272 {
273 this.report_changes_for_buffer(&buffer, &project, cx);
274 }
275 }
276 }),
277 cx.observe_release(buffer, move |this, _buffer, _cx| {
278 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
279 else {
280 return;
281 };
282 zeta_project.registered_buffers.remove(&buffer_id);
283 }),
284 ],
285 })
286 }
287 }
288 }
289
290 fn report_changes_for_buffer(
291 &mut self,
292 buffer: &Entity<Buffer>,
293 project: &Entity<Project>,
294 cx: &mut Context<Self>,
295 ) -> BufferSnapshot {
296 let zeta_project = self.get_or_init_zeta_project(project, cx);
297 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
298
299 let new_snapshot = buffer.read(cx).snapshot();
300 if new_snapshot.version != registered_buffer.snapshot.version {
301 let old_snapshot =
302 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
303 Self::push_event(
304 zeta_project,
305 Event::BufferChange {
306 old_snapshot,
307 new_snapshot: new_snapshot.clone(),
308 timestamp: Instant::now(),
309 },
310 );
311 }
312
313 new_snapshot
314 }
315
316 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
317 let events = &mut zeta_project.events;
318
319 if let Some(Event::BufferChange {
320 new_snapshot: last_new_snapshot,
321 timestamp: last_timestamp,
322 ..
323 }) = events.back_mut()
324 {
325 // Coalesce edits for the same buffer when they happen one after the other.
326 let Event::BufferChange {
327 old_snapshot,
328 new_snapshot,
329 timestamp,
330 } = &event;
331
332 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
333 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
334 && old_snapshot.version == last_new_snapshot.version
335 {
336 *last_new_snapshot = new_snapshot.clone();
337 *last_timestamp = *timestamp;
338 return;
339 }
340 }
341
342 if events.len() >= MAX_EVENT_COUNT {
343 // These are halved instead of popping to improve prompt caching.
344 events.drain(..MAX_EVENT_COUNT / 2);
345 }
346
347 events.push_back(event);
348 }
349
350 fn current_prediction_for_buffer(
351 &self,
352 buffer: &Entity<Buffer>,
353 project: &Entity<Project>,
354 cx: &App,
355 ) -> Option<BufferEditPrediction<'_>> {
356 let project_state = self.projects.get(&project.entity_id())?;
357
358 let CurrentEditPrediction {
359 requested_by_buffer_id,
360 prediction,
361 } = project_state.current_prediction.as_ref()?;
362
363 if prediction.targets_buffer(buffer.read(cx), cx) {
364 Some(BufferEditPrediction::Local { prediction })
365 } else if *requested_by_buffer_id == buffer.entity_id() {
366 Some(BufferEditPrediction::Jump { prediction })
367 } else {
368 None
369 }
370 }
371
372 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
373 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
374 project_state.current_prediction.take();
375 };
376 // TODO report accepted
377 }
378
379 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
380 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
381 project_state.current_prediction.take();
382 };
383 }
384
385 pub fn refresh_prediction(
386 &mut self,
387 project: &Entity<Project>,
388 buffer: &Entity<Buffer>,
389 position: language::Anchor,
390 cx: &mut Context<Self>,
391 ) -> Task<Result<()>> {
392 let request_task = self.request_prediction(project, buffer, position, cx);
393 let buffer = buffer.clone();
394 let project = project.clone();
395
396 cx.spawn(async move |this, cx| {
397 if let Some(prediction) = request_task.await? {
398 this.update(cx, |this, cx| {
399 let project_state = this
400 .projects
401 .get_mut(&project.entity_id())
402 .context("Project not found")?;
403
404 let new_prediction = CurrentEditPrediction {
405 requested_by_buffer_id: buffer.entity_id(),
406 prediction: prediction,
407 };
408
409 if project_state
410 .current_prediction
411 .as_ref()
412 .is_none_or(|old_prediction| {
413 new_prediction
414 .should_replace_prediction(&old_prediction, buffer.read(cx))
415 })
416 {
417 project_state.current_prediction = Some(new_prediction);
418 }
419 anyhow::Ok(())
420 })??;
421 }
422 Ok(())
423 })
424 }
425
426 fn request_prediction(
427 &mut self,
428 project: &Entity<Project>,
429 buffer: &Entity<Buffer>,
430 position: language::Anchor,
431 cx: &mut Context<Self>,
432 ) -> Task<Result<Option<EditPrediction>>> {
433 let project_state = self.projects.get(&project.entity_id());
434
435 let index_state = project_state.map(|state| {
436 state
437 .syntax_index
438 .read_with(cx, |index, _cx| index.state().clone())
439 });
440 let options = self.options.clone();
441 let snapshot = buffer.read(cx).snapshot();
442 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
443 return Task::ready(Err(anyhow!("No file path for excerpt")));
444 };
445 let client = self.client.clone();
446 let llm_token = self.llm_token.clone();
447 let app_version = AppVersion::global(cx);
448 let worktree_snapshots = project
449 .read(cx)
450 .worktrees(cx)
451 .map(|worktree| worktree.read(cx).snapshot())
452 .collect::<Vec<_>>();
453 let debug_tx = self.debug_tx.clone();
454
455 let events = project_state
456 .map(|state| {
457 state
458 .events
459 .iter()
460 .filter_map(|event| match event {
461 Event::BufferChange {
462 old_snapshot,
463 new_snapshot,
464 ..
465 } => {
466 let path = new_snapshot.file().map(|f| f.full_path(cx));
467
468 let old_path = old_snapshot.file().and_then(|f| {
469 let old_path = f.full_path(cx);
470 if Some(&old_path) != path.as_ref() {
471 Some(old_path)
472 } else {
473 None
474 }
475 });
476
477 // TODO [zeta2] move to bg?
478 let diff =
479 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
480
481 if path == old_path && diff.is_empty() {
482 None
483 } else {
484 Some(predict_edits_v3::Event::BufferChange {
485 old_path,
486 path,
487 diff,
488 //todo: Actually detect if this edit was predicted or not
489 predicted: false,
490 })
491 }
492 }
493 })
494 .collect::<Vec<_>>()
495 })
496 .unwrap_or_default();
497
498 let diagnostics = snapshot.diagnostic_sets().clone();
499
500 let request_task = cx.background_spawn({
501 let snapshot = snapshot.clone();
502 let buffer = buffer.clone();
503 async move {
504 let index_state = if let Some(index_state) = index_state {
505 Some(index_state.lock_owned().await)
506 } else {
507 None
508 };
509
510 let cursor_offset = position.to_offset(&snapshot);
511 let cursor_point = cursor_offset.to_point(&snapshot);
512
513 let before_retrieval = chrono::Utc::now();
514
515 let Some(context) = EditPredictionContext::gather_context(
516 cursor_point,
517 &snapshot,
518 &options.excerpt,
519 index_state.as_deref(),
520 ) else {
521 return Ok(None);
522 };
523
524 let debug_context = if let Some(debug_tx) = debug_tx {
525 Some((debug_tx, context.clone()))
526 } else {
527 None
528 };
529
530 let (diagnostic_groups, diagnostic_groups_truncated) =
531 Self::gather_nearby_diagnostics(
532 cursor_offset,
533 &diagnostics,
534 &snapshot,
535 options.max_diagnostic_bytes,
536 );
537
538 let request = make_cloud_request(
539 excerpt_path,
540 context,
541 events,
542 // TODO data collection
543 false,
544 diagnostic_groups,
545 diagnostic_groups_truncated,
546 None,
547 debug_context.is_some(),
548 &worktree_snapshots,
549 index_state.as_deref(),
550 Some(options.max_prompt_bytes),
551 options.prompt_format,
552 );
553
554 let retrieval_time = chrono::Utc::now() - before_retrieval;
555 let response = Self::perform_request(client, llm_token, app_version, request).await;
556
557 if let Some((debug_tx, context)) = debug_context {
558 debug_tx
559 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
560 |response| {
561 let Some(request) =
562 some_or_debug_panic(response.0.debug_info.clone())
563 else {
564 return Err("Missing debug info".to_string());
565 };
566 Ok(PredictionDebugInfo {
567 context,
568 request,
569 retrieval_time,
570 buffer: buffer.downgrade(),
571 position,
572 })
573 },
574 ))
575 .ok();
576 }
577
578 anyhow::Ok(Some(response?))
579 }
580 });
581
582 let buffer = buffer.clone();
583
584 cx.spawn({
585 let project = project.clone();
586 async move |this, cx| {
587 match request_task.await {
588 Ok(Some((response, usage))) => {
589 if let Some(usage) = usage {
590 this.update(cx, |this, cx| {
591 this.user_store.update(cx, |user_store, cx| {
592 user_store.update_edit_prediction_usage(usage, cx);
593 });
594 })
595 .ok();
596 }
597
598 let prediction = EditPrediction::from_response(
599 response, &snapshot, &buffer, &project, cx,
600 )
601 .await;
602
603 // TODO telemetry: duration, etc
604 Ok(prediction)
605 }
606 Ok(None) => Ok(None),
607 Err(err) => {
608 if err.is::<ZedUpdateRequiredError>() {
609 cx.update(|cx| {
610 this.update(cx, |this, _cx| {
611 this.update_required = true;
612 })
613 .ok();
614
615 let error_message: SharedString = err.to_string().into();
616 show_app_notification(
617 NotificationId::unique::<ZedUpdateRequiredError>(),
618 cx,
619 move |cx| {
620 cx.new(|cx| {
621 ErrorMessagePrompt::new(error_message.clone(), cx)
622 .with_link_button(
623 "Update Zed",
624 "https://zed.dev/releases",
625 )
626 })
627 },
628 );
629 })
630 .ok();
631 }
632
633 Err(err)
634 }
635 }
636 }
637 })
638 }
639
640 async fn perform_request(
641 client: Arc<Client>,
642 llm_token: LlmApiToken,
643 app_version: SemanticVersion,
644 request: predict_edits_v3::PredictEditsRequest,
645 ) -> Result<(
646 predict_edits_v3::PredictEditsResponse,
647 Option<EditPredictionUsage>,
648 )> {
649 let http_client = client.http_client();
650 let mut token = llm_token.acquire(&client).await?;
651 let mut did_retry = false;
652
653 loop {
654 let request_builder = http_client::Request::builder().method(Method::POST);
655 let request_builder =
656 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
657 request_builder.uri(predict_edits_url)
658 } else {
659 request_builder.uri(
660 http_client
661 .build_zed_llm_url("/predict_edits/v3", &[])?
662 .as_ref(),
663 )
664 };
665 let request = request_builder
666 .header("Content-Type", "application/json")
667 .header("Authorization", format!("Bearer {}", token))
668 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
669 .body(serde_json::to_string(&request)?.into())?;
670
671 let mut response = http_client.send(request).await?;
672
673 if let Some(minimum_required_version) = response
674 .headers()
675 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
676 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
677 {
678 anyhow::ensure!(
679 app_version >= minimum_required_version,
680 ZedUpdateRequiredError {
681 minimum_version: minimum_required_version
682 }
683 );
684 }
685
686 if response.status().is_success() {
687 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
688
689 let mut body = Vec::new();
690 response.body_mut().read_to_end(&mut body).await?;
691 return Ok((serde_json::from_slice(&body)?, usage));
692 } else if !did_retry
693 && response
694 .headers()
695 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
696 .is_some()
697 {
698 did_retry = true;
699 token = llm_token.refresh(&client).await?;
700 } else {
701 let mut body = String::new();
702 response.body_mut().read_to_string(&mut body).await?;
703 anyhow::bail!(
704 "error predicting edits.\nStatus: {:?}\nBody: {}",
705 response.status(),
706 body
707 );
708 }
709 }
710 }
711
712 fn gather_nearby_diagnostics(
713 cursor_offset: usize,
714 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
715 snapshot: &BufferSnapshot,
716 max_diagnostics_bytes: usize,
717 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
718 // TODO: Could make this more efficient
719 let mut diagnostic_groups = Vec::new();
720 for (language_server_id, diagnostics) in diagnostic_sets {
721 let mut groups = Vec::new();
722 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
723 diagnostic_groups.extend(
724 groups
725 .into_iter()
726 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
727 );
728 }
729
730 // sort by proximity to cursor
731 diagnostic_groups.sort_by_key(|group| {
732 let range = &group.entries[group.primary_ix].range;
733 if range.start >= cursor_offset {
734 range.start - cursor_offset
735 } else if cursor_offset >= range.end {
736 cursor_offset - range.end
737 } else {
738 (cursor_offset - range.start).min(range.end - cursor_offset)
739 }
740 });
741
742 let mut results = Vec::new();
743 let mut diagnostic_groups_truncated = false;
744 let mut diagnostics_byte_count = 0;
745 for group in diagnostic_groups {
746 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
747 diagnostics_byte_count += raw_value.get().len();
748 if diagnostics_byte_count > max_diagnostics_bytes {
749 diagnostic_groups_truncated = true;
750 break;
751 }
752 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
753 }
754
755 (results, diagnostic_groups_truncated)
756 }
757
758 // TODO: Dedupe with similar code in request_prediction?
759 pub fn cloud_request_for_zeta_cli(
760 &mut self,
761 project: &Entity<Project>,
762 buffer: &Entity<Buffer>,
763 position: language::Anchor,
764 cx: &mut Context<Self>,
765 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
766 let project_state = self.projects.get(&project.entity_id());
767
768 let index_state = project_state.map(|state| {
769 state
770 .syntax_index
771 .read_with(cx, |index, _cx| index.state().clone())
772 });
773 let options = self.options.clone();
774 let snapshot = buffer.read(cx).snapshot();
775 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
776 return Task::ready(Err(anyhow!("No file path for excerpt")));
777 };
778 let worktree_snapshots = project
779 .read(cx)
780 .worktrees(cx)
781 .map(|worktree| worktree.read(cx).snapshot())
782 .collect::<Vec<_>>();
783
784 cx.background_spawn(async move {
785 let index_state = if let Some(index_state) = index_state {
786 Some(index_state.lock_owned().await)
787 } else {
788 None
789 };
790
791 let cursor_point = position.to_point(&snapshot);
792
793 let debug_info = true;
794 EditPredictionContext::gather_context(
795 cursor_point,
796 &snapshot,
797 &options.excerpt,
798 index_state.as_deref(),
799 )
800 .context("Failed to select excerpt")
801 .map(|context| {
802 make_cloud_request(
803 excerpt_path.into(),
804 context,
805 // TODO pass everything
806 Vec::new(),
807 false,
808 Vec::new(),
809 false,
810 None,
811 debug_info,
812 &worktree_snapshots,
813 index_state.as_deref(),
814 Some(options.max_prompt_bytes),
815 options.prompt_format,
816 )
817 })
818 })
819 }
820}
821
/// Returned by `Zeta::perform_request` when the server's minimum-required-version header
/// is greater than this Zed's version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum Zed version reported by the server.
    minimum_version: SemanticVersion,
}
829
830fn make_cloud_request(
831 excerpt_path: Arc<Path>,
832 context: EditPredictionContext,
833 events: Vec<predict_edits_v3::Event>,
834 can_collect_data: bool,
835 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
836 diagnostic_groups_truncated: bool,
837 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
838 debug_info: bool,
839 worktrees: &Vec<worktree::Snapshot>,
840 index_state: Option<&SyntaxIndexState>,
841 prompt_max_bytes: Option<usize>,
842 prompt_format: PromptFormat,
843) -> predict_edits_v3::PredictEditsRequest {
844 let mut signatures = Vec::new();
845 let mut declaration_to_signature_index = HashMap::default();
846 let mut referenced_declarations = Vec::new();
847
848 for snippet in context.declarations {
849 let project_entry_id = snippet.declaration.project_entry_id();
850 let Some(path) = worktrees.iter().find_map(|worktree| {
851 worktree.entry_for_id(project_entry_id).map(|entry| {
852 let mut full_path = RelPathBuf::new();
853 full_path.push(worktree.root_name());
854 full_path.push(&entry.path);
855 full_path
856 })
857 }) else {
858 continue;
859 };
860
861 let parent_index = index_state.and_then(|index_state| {
862 snippet.declaration.parent().and_then(|parent| {
863 add_signature(
864 parent,
865 &mut declaration_to_signature_index,
866 &mut signatures,
867 index_state,
868 )
869 })
870 });
871
872 let (text, text_is_truncated) = snippet.declaration.item_text();
873 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
874 path: path.as_std_path().into(),
875 text: text.into(),
876 range: snippet.declaration.item_range(),
877 text_is_truncated,
878 signature_range: snippet.declaration.signature_range_in_item_text(),
879 parent_index,
880 score_components: snippet.score_components,
881 signature_score: snippet.scores.signature,
882 declaration_score: snippet.scores.declaration,
883 });
884 }
885
886 let excerpt_parent = index_state.and_then(|index_state| {
887 context
888 .excerpt
889 .parent_declarations
890 .last()
891 .and_then(|(parent, _)| {
892 add_signature(
893 *parent,
894 &mut declaration_to_signature_index,
895 &mut signatures,
896 index_state,
897 )
898 })
899 });
900
901 predict_edits_v3::PredictEditsRequest {
902 excerpt_path,
903 excerpt: context.excerpt_text.body,
904 excerpt_range: context.excerpt.range,
905 cursor_offset: context.cursor_offset_in_excerpt,
906 referenced_declarations,
907 signatures,
908 excerpt_parent,
909 events,
910 can_collect_data,
911 diagnostic_groups,
912 diagnostic_groups_truncated,
913 git_info,
914 debug_info,
915 prompt_max_bytes,
916 prompt_format,
917 }
918}
919
920fn add_signature(
921 declaration_id: DeclarationId,
922 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
923 signatures: &mut Vec<Signature>,
924 index: &SyntaxIndexState,
925) -> Option<usize> {
926 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
927 return Some(*signature_index);
928 }
929 let Some(parent_declaration) = index.declaration(declaration_id) else {
930 log::error!("bug: missing parent declaration");
931 return None;
932 };
933 let parent_index = parent_declaration.parent().and_then(|parent| {
934 add_signature(parent, declaration_to_signature_index, signatures, index)
935 });
936 let (text, text_is_truncated) = parent_declaration.signature_text();
937 let signature_index = signatures.len();
938 signatures.push(Signature {
939 text: text.into(),
940 text_is_truncated,
941 parent_index,
942 range: parent_declaration.signature_range(),
943 });
944 declaration_to_signature_index.insert(declaration_id, signature_index);
945 Some(signature_index)
946}
947
948#[cfg(test)]
949mod tests {
950 use std::{
951 path::{Path, PathBuf},
952 sync::Arc,
953 };
954
955 use client::UserStore;
956 use clock::FakeSystemClock;
957 use cloud_llm_client::predict_edits_v3;
958 use futures::{
959 AsyncReadExt, StreamExt,
960 channel::{mpsc, oneshot},
961 };
962 use gpui::{
963 Entity, TestAppContext,
964 http_client::{FakeHttpClient, Response},
965 prelude::*,
966 };
967 use indoc::indoc;
968 use language::{LanguageServerId, OffsetRangeExt as _};
969 use pretty_assertions::{assert_eq, assert_matches};
970 use project::{FakeFs, Project};
971 use serde_json::json;
972 use settings::SettingsStore;
973 use util::path;
974 use uuid::Uuid;
975
976 use crate::{BufferEditPrediction, Zeta};
977
978 #[gpui::test]
979 async fn test_current_state(cx: &mut TestAppContext) {
980 let (zeta, mut req_rx) = init_test(cx);
981 let fs = FakeFs::new(cx.executor());
982 fs.insert_tree(
983 "/root",
984 json!({
985 "1.txt": "Hello!\nHow\nBye",
986 "2.txt": "Hola!\nComo\nAdios"
987 }),
988 )
989 .await;
990 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
991
992 zeta.update(cx, |zeta, cx| {
993 zeta.register_project(&project, cx);
994 });
995
996 let buffer1 = project
997 .update(cx, |project, cx| {
998 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
999 project.open_buffer(path, cx)
1000 })
1001 .await
1002 .unwrap();
1003 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1004 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1005
1006 // Prediction for current file
1007
1008 let prediction_task = zeta.update(cx, |zeta, cx| {
1009 zeta.refresh_prediction(&project, &buffer1, position, cx)
1010 });
1011 let (_request, respond_tx) = req_rx.next().await.unwrap();
1012 respond_tx
1013 .send(predict_edits_v3::PredictEditsResponse {
1014 request_id: Uuid::new_v4(),
1015 edits: vec![predict_edits_v3::Edit {
1016 path: Path::new(path!("root/1.txt")).into(),
1017 range: 0..snapshot1.len(),
1018 content: "Hello!\nHow are you?\nBye".into(),
1019 }],
1020 debug_info: None,
1021 })
1022 .unwrap();
1023 prediction_task.await.unwrap();
1024
1025 zeta.read_with(cx, |zeta, cx| {
1026 let prediction = zeta
1027 .current_prediction_for_buffer(&buffer1, &project, cx)
1028 .unwrap();
1029 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1030 });
1031
1032 // Prediction for another file
1033
1034 let prediction_task = zeta.update(cx, |zeta, cx| {
1035 zeta.refresh_prediction(&project, &buffer1, position, cx)
1036 });
1037 let (_request, respond_tx) = req_rx.next().await.unwrap();
1038 respond_tx
1039 .send(predict_edits_v3::PredictEditsResponse {
1040 request_id: Uuid::new_v4(),
1041 edits: vec![predict_edits_v3::Edit {
1042 path: Path::new(path!("root/2.txt")).into(),
1043 range: 0..snapshot1.len(),
1044 content: "Hola!\nComo estas?\nAdios".into(),
1045 }],
1046 debug_info: None,
1047 })
1048 .unwrap();
1049 prediction_task.await.unwrap();
1050
1051 zeta.read_with(cx, |zeta, cx| {
1052 let prediction = zeta
1053 .current_prediction_for_buffer(&buffer1, &project, cx)
1054 .unwrap();
1055 assert_matches!(
1056 prediction,
1057 BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1058 );
1059 });
1060
1061 let buffer2 = project
1062 .update(cx, |project, cx| {
1063 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1064 project.open_buffer(path, cx)
1065 })
1066 .await
1067 .unwrap();
1068
1069 zeta.read_with(cx, |zeta, cx| {
1070 let prediction = zeta
1071 .current_prediction_for_buffer(&buffer2, &project, cx)
1072 .unwrap();
1073 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1074 });
1075 }
1076
1077 #[gpui::test]
1078 async fn test_simple_request(cx: &mut TestAppContext) {
1079 let (zeta, mut req_rx) = init_test(cx);
1080 let fs = FakeFs::new(cx.executor());
1081 fs.insert_tree(
1082 "/root",
1083 json!({
1084 "foo.md": "Hello!\nHow\nBye"
1085 }),
1086 )
1087 .await;
1088 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1089
1090 let buffer = project
1091 .update(cx, |project, cx| {
1092 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1093 project.open_buffer(path, cx)
1094 })
1095 .await
1096 .unwrap();
1097 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1098 let position = snapshot.anchor_before(language::Point::new(1, 3));
1099
1100 let prediction_task = zeta.update(cx, |zeta, cx| {
1101 zeta.request_prediction(&project, &buffer, position, cx)
1102 });
1103
1104 let (request, respond_tx) = req_rx.next().await.unwrap();
1105 assert_eq!(
1106 request.excerpt_path.as_ref(),
1107 Path::new(path!("root/foo.md"))
1108 );
1109 assert_eq!(request.cursor_offset, 10);
1110
1111 respond_tx
1112 .send(predict_edits_v3::PredictEditsResponse {
1113 request_id: Uuid::new_v4(),
1114 edits: vec![predict_edits_v3::Edit {
1115 path: Path::new(path!("root/foo.md")).into(),
1116 range: 0..snapshot.len(),
1117 content: "Hello!\nHow are you?\nBye".into(),
1118 }],
1119 debug_info: None,
1120 })
1121 .unwrap();
1122
1123 let prediction = prediction_task.await.unwrap().unwrap();
1124
1125 assert_eq!(prediction.edits.len(), 1);
1126 assert_eq!(
1127 prediction.edits[0].0.to_point(&snapshot).start,
1128 language::Point::new(1, 3)
1129 );
1130 assert_eq!(prediction.edits[0].1, " are you?");
1131 }
1132
1133 #[gpui::test]
1134 async fn test_request_events(cx: &mut TestAppContext) {
1135 let (zeta, mut req_rx) = init_test(cx);
1136 let fs = FakeFs::new(cx.executor());
1137 fs.insert_tree(
1138 "/root",
1139 json!({
1140 "foo.md": "Hello!\n\nBye"
1141 }),
1142 )
1143 .await;
1144 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1145
1146 let buffer = project
1147 .update(cx, |project, cx| {
1148 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1149 project.open_buffer(path, cx)
1150 })
1151 .await
1152 .unwrap();
1153
1154 zeta.update(cx, |zeta, cx| {
1155 zeta.register_buffer(&buffer, &project, cx);
1156 });
1157
1158 buffer.update(cx, |buffer, cx| {
1159 buffer.edit(vec![(7..7, "How")], None, cx);
1160 });
1161
1162 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1163 let position = snapshot.anchor_before(language::Point::new(1, 3));
1164
1165 let prediction_task = zeta.update(cx, |zeta, cx| {
1166 zeta.request_prediction(&project, &buffer, position, cx)
1167 });
1168
1169 let (request, respond_tx) = req_rx.next().await.unwrap();
1170
1171 assert_eq!(request.events.len(), 1);
1172 assert_eq!(
1173 request.events[0],
1174 predict_edits_v3::Event::BufferChange {
1175 path: Some(PathBuf::from(path!("root/foo.md"))),
1176 old_path: None,
1177 diff: indoc! {"
1178 @@ -1,3 +1,3 @@
1179 Hello!
1180 -
1181 +How
1182 Bye
1183 "}
1184 .to_string(),
1185 predicted: false
1186 }
1187 );
1188
1189 respond_tx
1190 .send(predict_edits_v3::PredictEditsResponse {
1191 request_id: Uuid::new_v4(),
1192 edits: vec![predict_edits_v3::Edit {
1193 path: Path::new(path!("root/foo.md")).into(),
1194 range: 0..snapshot.len(),
1195 content: "Hello!\nHow are you?\nBye".into(),
1196 }],
1197 debug_info: None,
1198 })
1199 .unwrap();
1200
1201 let prediction = prediction_task.await.unwrap().unwrap();
1202
1203 assert_eq!(prediction.edits.len(), 1);
1204 assert_eq!(
1205 prediction.edits[0].0.to_point(&snapshot).start,
1206 language::Point::new(1, 3)
1207 );
1208 assert_eq!(prediction.edits[0].1, " are you?");
1209 }
1210
1211 #[gpui::test]
1212 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1213 let (zeta, mut req_rx) = init_test(cx);
1214 let fs = FakeFs::new(cx.executor());
1215 fs.insert_tree(
1216 "/root",
1217 json!({
1218 "foo.md": "Hello!\nBye"
1219 }),
1220 )
1221 .await;
1222 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1223
1224 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1225 let diagnostic = lsp::Diagnostic {
1226 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1227 severity: Some(lsp::DiagnosticSeverity::ERROR),
1228 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1229 ..Default::default()
1230 };
1231
1232 project.update(cx, |project, cx| {
1233 project.lsp_store().update(cx, |lsp_store, cx| {
1234 // Create some diagnostics
1235 lsp_store
1236 .update_diagnostics(
1237 LanguageServerId(0),
1238 lsp::PublishDiagnosticsParams {
1239 uri: path_to_buffer_uri.clone(),
1240 diagnostics: vec![diagnostic],
1241 version: None,
1242 },
1243 None,
1244 language::DiagnosticSourceKind::Pushed,
1245 &[],
1246 cx,
1247 )
1248 .unwrap();
1249 });
1250 });
1251
1252 let buffer = project
1253 .update(cx, |project, cx| {
1254 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1255 project.open_buffer(path, cx)
1256 })
1257 .await
1258 .unwrap();
1259
1260 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1261 let position = snapshot.anchor_before(language::Point::new(0, 0));
1262
1263 let _prediction_task = zeta.update(cx, |zeta, cx| {
1264 zeta.request_prediction(&project, &buffer, position, cx)
1265 });
1266
1267 let (request, _respond_tx) = req_rx.next().await.unwrap();
1268
1269 assert_eq!(request.diagnostic_groups.len(), 1);
1270 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1271 .unwrap();
1272 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1273 assert_eq!(
1274 value,
1275 json!({
1276 "entries": [{
1277 "range": {
1278 "start": 8,
1279 "end": 10
1280 },
1281 "diagnostic": {
1282 "source": null,
1283 "code": null,
1284 "code_description": null,
1285 "severity": 1,
1286 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1287 "markdown": null,
1288 "group_id": 0,
1289 "is_primary": true,
1290 "is_disk_based": false,
1291 "is_unnecessary": false,
1292 "source_kind": "Pushed",
1293 "data": null,
1294 "underline": true
1295 }
1296 }],
1297 "primary_ix": 0
1298 })
1299 );
1300 }
1301
1302 fn init_test(
1303 cx: &mut TestAppContext,
1304 ) -> (
1305 Entity<Zeta>,
1306 mpsc::UnboundedReceiver<(
1307 predict_edits_v3::PredictEditsRequest,
1308 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1309 )>,
1310 ) {
1311 cx.update(move |cx| {
1312 let settings_store = SettingsStore::test(cx);
1313 cx.set_global(settings_store);
1314 language::init(cx);
1315 Project::init_settings(cx);
1316
1317 let (req_tx, req_rx) = mpsc::unbounded();
1318
1319 let http_client = FakeHttpClient::create({
1320 move |req| {
1321 let uri = req.uri().path().to_string();
1322 let mut body = req.into_body();
1323 let req_tx = req_tx.clone();
1324 async move {
1325 let resp = match uri.as_str() {
1326 "/client/llm_tokens" => serde_json::to_string(&json!({
1327 "token": "test"
1328 }))
1329 .unwrap(),
1330 "/predict_edits/v3" => {
1331 let mut buf = Vec::new();
1332 body.read_to_end(&mut buf).await.ok();
1333 let req = serde_json::from_slice(&buf).unwrap();
1334
1335 let (res_tx, res_rx) = oneshot::channel();
1336 req_tx.unbounded_send((req, res_tx)).unwrap();
1337 serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1338 }
1339 _ => {
1340 panic!("Unexpected path: {}", uri)
1341 }
1342 };
1343
1344 Ok(Response::builder().body(resp.into()).unwrap())
1345 }
1346 }
1347 });
1348
1349 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1350 client.cloud_client().set_credentials(1, "test".into());
1351
1352 language_model::init(client.clone(), cx);
1353
1354 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1355 let zeta = Zeta::global(&client, &user_store, cx);
1356
1357 (zeta, req_rx)
1358 })
1359 }
1360}