1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
21use language::{BufferSnapshot, TextBufferSnapshot};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::EditPrediction;
39pub use provider::ZetaEditPredictionProvider;
40
/// Consecutive edits to the same buffer that are at most this far apart are coalesced
/// into a single buffer-change event (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1000);

/// Upper bound on the number of buffer-change events kept per project.
const MAX_EVENT_COUNT: usize = 16;
45
46pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
47 max_bytes: 512,
48 min_bytes: 128,
49 target_before_cursor_over_total_bytes: 0.5,
50};
51
52pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
53 excerpt: DEFAULT_EXCERPT_OPTIONS,
54 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
55 max_diagnostic_bytes: 2048,
56 prompt_format: PromptFormat::DEFAULT,
57 file_indexing_parallelism: 1,
58};
59
/// Wrapper that stores the app-wide [`Zeta`] entity as a gpui global
/// (set by `Zeta::global`, read by `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
64
/// Edit-prediction engine: tracks per-project buffer history and requests predictions
/// from the cloud `/predict_edits/v3` endpoint.
pub struct Zeta {
    /// Client used for HTTP requests and LLM token acquisition.
    client: Arc<Client>,
    /// Store whose edit-prediction usage is updated from response headers.
    user_store: Entity<UserStore>,
    /// Token for the LLM endpoint; refreshed on listener events and on expired-token responses.
    llm_token: LlmApiToken,
    /// Held so the token-refresh subscription stays active (presumably dropping it unsubscribes).
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to each prediction request.
    options: ZetaOptions,
    /// Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// Installed by `debug_info`; when present, requests send their debug info (or an error string) here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
75
/// Tunable parameters for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Limits used when selecting the excerpt around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
    /// Prompt size limit forwarded to the server as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project's index is created.
    pub file_indexing_parallelism: usize,
}
84
/// Debug snapshot of one prediction request, sent through the channel returned by
/// `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time from just before context gathering until the request payload was built
    /// (includes diagnostics collection; excludes the HTTP round trip).
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the prediction response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-provided debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
94
/// Per-project prediction state.
struct ZetaProject {
    /// Syntax index used to gather declaration context for this project.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first; bounded by `MAX_EVENT_COUNT` in `push_event`.
    events: VecDeque<Event>,
    /// Buffers being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// Latest prediction that has not been accepted or discarded yet.
    current_prediction: Option<CurrentEditPrediction>,
}
101
/// The prediction currently offered for a project, together with the buffer that requested it.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer whose request produced this prediction.
    pub requested_by_buffer_id: EntityId,
    /// The predicted edits.
    pub prediction: EditPrediction,
}
107
108impl CurrentEditPrediction {
109 fn should_replace_prediction(
110 &self,
111 old_prediction: &Self,
112 snapshot: &TextBufferSnapshot,
113 ) -> bool {
114 if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
115 return true;
116 }
117
118 let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
119 return true;
120 };
121
122 let Some(new_edits) = self.prediction.interpolate(snapshot) else {
123 return false;
124 };
125 if old_edits.len() == 1 && new_edits.len() == 1 {
126 let (old_range, old_text) = &old_edits[0];
127 let (new_range, new_text) = &new_edits[0];
128 new_range == old_range && new_text.starts_with(old_text)
129 } else {
130 true
131 }
132 }
133}
134
/// The project's current prediction as seen from one particular buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer (`EditPrediction::targets_buffer` returned true).
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer but was requested from this one
    /// (presumably surfaced as a jump to that location).
    Jump { prediction: &'a EditPrediction },
}
141
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// Snapshot last seen for this buffer; its version is compared with new snapshots to detect changes.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer installed by `register_buffer_impl`.
    _subscriptions: [gpui::Subscription; 2],
}
146
/// A change recorded for a project and later sent as request context.
#[derive(Clone)]
pub enum Event {
    /// A buffer's version advanced from `old_snapshot` to `new_snapshot`. Consecutive
    /// changes may be coalesced into one event (see `Zeta::push_event`), in which case
    /// `timestamp` is the time of the latest change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
155
156impl Zeta {
157 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
158 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
159 }
160
161 pub fn global(
162 client: &Arc<Client>,
163 user_store: &Entity<UserStore>,
164 cx: &mut App,
165 ) -> Entity<Self> {
166 cx.try_global::<ZetaGlobal>()
167 .map(|global| global.0.clone())
168 .unwrap_or_else(|| {
169 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
170 cx.set_global(ZetaGlobal(zeta.clone()));
171 zeta
172 })
173 }
174
175 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
176 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
177
178 Self {
179 projects: HashMap::new(),
180 client,
181 user_store,
182 options: DEFAULT_OPTIONS,
183 llm_token: LlmApiToken::default(),
184 _llm_token_subscription: cx.subscribe(
185 &refresh_llm_token_listener,
186 |this, _listener, _event, cx| {
187 let client = this.client.clone();
188 let llm_token = this.llm_token.clone();
189 cx.spawn(async move |_this, _cx| {
190 llm_token.refresh(&client).await?;
191 anyhow::Ok(())
192 })
193 .detach_and_log_err(cx);
194 },
195 ),
196 update_required: false,
197 debug_tx: None,
198 }
199 }
200
201 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
202 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
203 self.debug_tx = Some(debug_watch_tx);
204 debug_watch_rx
205 }
206
207 pub fn options(&self) -> &ZetaOptions {
208 &self.options
209 }
210
211 pub fn set_options(&mut self, options: ZetaOptions) {
212 self.options = options;
213 }
214
215 pub fn clear_history(&mut self) {
216 for zeta_project in self.projects.values_mut() {
217 zeta_project.events.clear();
218 }
219 }
220
221 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
222 self.user_store.read(cx).edit_prediction_usage()
223 }
224
225 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
226 self.get_or_init_zeta_project(project, cx);
227 }
228
229 pub fn register_buffer(
230 &mut self,
231 buffer: &Entity<Buffer>,
232 project: &Entity<Project>,
233 cx: &mut Context<Self>,
234 ) {
235 let zeta_project = self.get_or_init_zeta_project(project, cx);
236 Self::register_buffer_impl(zeta_project, buffer, project, cx);
237 }
238
239 fn get_or_init_zeta_project(
240 &mut self,
241 project: &Entity<Project>,
242 cx: &mut App,
243 ) -> &mut ZetaProject {
244 self.projects
245 .entry(project.entity_id())
246 .or_insert_with(|| ZetaProject {
247 syntax_index: cx.new(|cx| {
248 SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
249 }),
250 events: VecDeque::new(),
251 registered_buffers: HashMap::new(),
252 current_prediction: None,
253 })
254 }
255
    /// Returns the registration for `buffer` in `zeta_project`, creating it on first use.
    ///
    /// A new registration records the buffer's current snapshot and installs two
    /// subscriptions:
    /// - one that calls `report_changes_for_buffer` whenever the buffer emits
    ///   `BufferEvent::Edited` and the project is still alive;
    /// - one that removes the registration when the buffer entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                // Only the id is captured by the release observer, which looks the project up again.
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Weak handle, presumably so this subscription does not keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            // The project may already have been removed; nothing to clean up then.
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
293
294 fn report_changes_for_buffer(
295 &mut self,
296 buffer: &Entity<Buffer>,
297 project: &Entity<Project>,
298 cx: &mut Context<Self>,
299 ) -> BufferSnapshot {
300 let zeta_project = self.get_or_init_zeta_project(project, cx);
301 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
302
303 let new_snapshot = buffer.read(cx).snapshot();
304 if new_snapshot.version != registered_buffer.snapshot.version {
305 let old_snapshot =
306 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
307 Self::push_event(
308 zeta_project,
309 Event::BufferChange {
310 old_snapshot,
311 new_snapshot: new_snapshot.clone(),
312 timestamp: Instant::now(),
313 },
314 );
315 }
316
317 new_snapshot
318 }
319
320 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
321 let events = &mut zeta_project.events;
322
323 if let Some(Event::BufferChange {
324 new_snapshot: last_new_snapshot,
325 timestamp: last_timestamp,
326 ..
327 }) = events.back_mut()
328 {
329 // Coalesce edits for the same buffer when they happen one after the other.
330 let Event::BufferChange {
331 old_snapshot,
332 new_snapshot,
333 timestamp,
334 } = &event;
335
336 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
337 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
338 && old_snapshot.version == last_new_snapshot.version
339 {
340 *last_new_snapshot = new_snapshot.clone();
341 *last_timestamp = *timestamp;
342 return;
343 }
344 }
345
346 if events.len() >= MAX_EVENT_COUNT {
347 // These are halved instead of popping to improve prompt caching.
348 events.drain(..MAX_EVENT_COUNT / 2);
349 }
350
351 events.push_back(event);
352 }
353
354 fn current_prediction_for_buffer(
355 &self,
356 buffer: &Entity<Buffer>,
357 project: &Entity<Project>,
358 cx: &App,
359 ) -> Option<BufferEditPrediction<'_>> {
360 let project_state = self.projects.get(&project.entity_id())?;
361
362 let CurrentEditPrediction {
363 requested_by_buffer_id,
364 prediction,
365 } = project_state.current_prediction.as_ref()?;
366
367 if prediction.targets_buffer(buffer.read(cx), cx) {
368 Some(BufferEditPrediction::Local { prediction })
369 } else if *requested_by_buffer_id == buffer.entity_id() {
370 Some(BufferEditPrediction::Jump { prediction })
371 } else {
372 None
373 }
374 }
375
376 fn accept_current_prediction(&mut self, project: &Entity<Project>) {
377 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
378 project_state.current_prediction.take();
379 };
380 // TODO report accepted
381 }
382
383 fn discard_current_prediction(&mut self, project: &Entity<Project>) {
384 if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
385 project_state.current_prediction.take();
386 };
387 }
388
389 pub fn refresh_prediction(
390 &mut self,
391 project: &Entity<Project>,
392 buffer: &Entity<Buffer>,
393 position: language::Anchor,
394 cx: &mut Context<Self>,
395 ) -> Task<Result<()>> {
396 let request_task = self.request_prediction(project, buffer, position, cx);
397 let buffer = buffer.clone();
398 let project = project.clone();
399
400 cx.spawn(async move |this, cx| {
401 if let Some(prediction) = request_task.await? {
402 this.update(cx, |this, cx| {
403 let project_state = this
404 .projects
405 .get_mut(&project.entity_id())
406 .context("Project not found")?;
407
408 let new_prediction = CurrentEditPrediction {
409 requested_by_buffer_id: buffer.entity_id(),
410 prediction: prediction,
411 };
412
413 if project_state
414 .current_prediction
415 .as_ref()
416 .is_none_or(|old_prediction| {
417 new_prediction
418 .should_replace_prediction(&old_prediction, buffer.read(cx))
419 })
420 {
421 project_state.current_prediction = Some(new_prediction);
422 }
423 anyhow::Ok(())
424 })??;
425 }
426 Ok(())
427 })
428 }
429
430 fn request_prediction(
431 &mut self,
432 project: &Entity<Project>,
433 buffer: &Entity<Buffer>,
434 position: language::Anchor,
435 cx: &mut Context<Self>,
436 ) -> Task<Result<Option<EditPrediction>>> {
437 let project_state = self.projects.get(&project.entity_id());
438
439 let index_state = project_state.map(|state| {
440 state
441 .syntax_index
442 .read_with(cx, |index, _cx| index.state().clone())
443 });
444 let options = self.options.clone();
445 let snapshot = buffer.read(cx).snapshot();
446 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
447 return Task::ready(Err(anyhow!("No file path for excerpt")));
448 };
449 let client = self.client.clone();
450 let llm_token = self.llm_token.clone();
451 let app_version = AppVersion::global(cx);
452 let worktree_snapshots = project
453 .read(cx)
454 .worktrees(cx)
455 .map(|worktree| worktree.read(cx).snapshot())
456 .collect::<Vec<_>>();
457 let debug_tx = self.debug_tx.clone();
458
459 let events = project_state
460 .map(|state| {
461 state
462 .events
463 .iter()
464 .filter_map(|event| match event {
465 Event::BufferChange {
466 old_snapshot,
467 new_snapshot,
468 ..
469 } => {
470 let path = new_snapshot.file().map(|f| f.full_path(cx));
471
472 let old_path = old_snapshot.file().and_then(|f| {
473 let old_path = f.full_path(cx);
474 if Some(&old_path) != path.as_ref() {
475 Some(old_path)
476 } else {
477 None
478 }
479 });
480
481 // TODO [zeta2] move to bg?
482 let diff =
483 language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
484
485 if path == old_path && diff.is_empty() {
486 None
487 } else {
488 Some(predict_edits_v3::Event::BufferChange {
489 old_path,
490 path,
491 diff,
492 //todo: Actually detect if this edit was predicted or not
493 predicted: false,
494 })
495 }
496 }
497 })
498 .collect::<Vec<_>>()
499 })
500 .unwrap_or_default();
501
502 let diagnostics = snapshot.diagnostic_sets().clone();
503
504 let request_task = cx.background_spawn({
505 let snapshot = snapshot.clone();
506 let buffer = buffer.clone();
507 async move {
508 let index_state = if let Some(index_state) = index_state {
509 Some(index_state.lock_owned().await)
510 } else {
511 None
512 };
513
514 let cursor_offset = position.to_offset(&snapshot);
515 let cursor_point = cursor_offset.to_point(&snapshot);
516
517 let before_retrieval = chrono::Utc::now();
518
519 let Some(context) = EditPredictionContext::gather_context(
520 cursor_point,
521 &snapshot,
522 &options.excerpt,
523 index_state.as_deref(),
524 ) else {
525 return Ok(None);
526 };
527
528 let debug_context = if let Some(debug_tx) = debug_tx {
529 Some((debug_tx, context.clone()))
530 } else {
531 None
532 };
533
534 let (diagnostic_groups, diagnostic_groups_truncated) =
535 Self::gather_nearby_diagnostics(
536 cursor_offset,
537 &diagnostics,
538 &snapshot,
539 options.max_diagnostic_bytes,
540 );
541
542 let request = make_cloud_request(
543 excerpt_path,
544 context,
545 events,
546 // TODO data collection
547 false,
548 diagnostic_groups,
549 diagnostic_groups_truncated,
550 None,
551 debug_context.is_some(),
552 &worktree_snapshots,
553 index_state.as_deref(),
554 Some(options.max_prompt_bytes),
555 options.prompt_format,
556 );
557
558 let retrieval_time = chrono::Utc::now() - before_retrieval;
559 let response = Self::perform_request(client, llm_token, app_version, request).await;
560
561 if let Some((debug_tx, context)) = debug_context {
562 debug_tx
563 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
564 |response| {
565 let Some(request) =
566 some_or_debug_panic(response.0.debug_info.clone())
567 else {
568 return Err("Missing debug info".to_string());
569 };
570 Ok(PredictionDebugInfo {
571 context,
572 request,
573 retrieval_time,
574 buffer: buffer.downgrade(),
575 position,
576 })
577 },
578 ))
579 .ok();
580 }
581
582 anyhow::Ok(Some(response?))
583 }
584 });
585
586 let buffer = buffer.clone();
587
588 cx.spawn({
589 let project = project.clone();
590 async move |this, cx| {
591 match request_task.await {
592 Ok(Some((response, usage))) => {
593 if let Some(usage) = usage {
594 this.update(cx, |this, cx| {
595 this.user_store.update(cx, |user_store, cx| {
596 user_store.update_edit_prediction_usage(usage, cx);
597 });
598 })
599 .ok();
600 }
601
602 let prediction = EditPrediction::from_response(
603 response, &snapshot, &buffer, &project, cx,
604 )
605 .await;
606
607 // TODO telemetry: duration, etc
608 Ok(prediction)
609 }
610 Ok(None) => Ok(None),
611 Err(err) => {
612 if err.is::<ZedUpdateRequiredError>() {
613 cx.update(|cx| {
614 this.update(cx, |this, _cx| {
615 this.update_required = true;
616 })
617 .ok();
618
619 let error_message: SharedString = err.to_string().into();
620 show_app_notification(
621 NotificationId::unique::<ZedUpdateRequiredError>(),
622 cx,
623 move |cx| {
624 cx.new(|cx| {
625 ErrorMessagePrompt::new(error_message.clone(), cx)
626 .with_link_button(
627 "Update Zed",
628 "https://zed.dev/releases",
629 )
630 })
631 },
632 );
633 })
634 .ok();
635 }
636
637 Err(err)
638 }
639 }
640 }
641 })
642 }
643
644 async fn perform_request(
645 client: Arc<Client>,
646 llm_token: LlmApiToken,
647 app_version: SemanticVersion,
648 request: predict_edits_v3::PredictEditsRequest,
649 ) -> Result<(
650 predict_edits_v3::PredictEditsResponse,
651 Option<EditPredictionUsage>,
652 )> {
653 let http_client = client.http_client();
654 let mut token = llm_token.acquire(&client).await?;
655 let mut did_retry = false;
656
657 loop {
658 let request_builder = http_client::Request::builder().method(Method::POST);
659 let request_builder =
660 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
661 request_builder.uri(predict_edits_url)
662 } else {
663 request_builder.uri(
664 http_client
665 .build_zed_llm_url("/predict_edits/v3", &[])?
666 .as_ref(),
667 )
668 };
669 let request = request_builder
670 .header("Content-Type", "application/json")
671 .header("Authorization", format!("Bearer {}", token))
672 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
673 .body(serde_json::to_string(&request)?.into())?;
674
675 let mut response = http_client.send(request).await?;
676
677 if let Some(minimum_required_version) = response
678 .headers()
679 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
680 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
681 {
682 anyhow::ensure!(
683 app_version >= minimum_required_version,
684 ZedUpdateRequiredError {
685 minimum_version: minimum_required_version
686 }
687 );
688 }
689
690 if response.status().is_success() {
691 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
692
693 let mut body = Vec::new();
694 response.body_mut().read_to_end(&mut body).await?;
695 return Ok((serde_json::from_slice(&body)?, usage));
696 } else if !did_retry
697 && response
698 .headers()
699 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
700 .is_some()
701 {
702 did_retry = true;
703 token = llm_token.refresh(&client).await?;
704 } else {
705 let mut body = String::new();
706 response.body_mut().read_to_string(&mut body).await?;
707 anyhow::bail!(
708 "error predicting edits.\nStatus: {:?}\nBody: {}",
709 response.status(),
710 body
711 );
712 }
713 }
714 }
715
716 fn gather_nearby_diagnostics(
717 cursor_offset: usize,
718 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
719 snapshot: &BufferSnapshot,
720 max_diagnostics_bytes: usize,
721 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
722 // TODO: Could make this more efficient
723 let mut diagnostic_groups = Vec::new();
724 for (language_server_id, diagnostics) in diagnostic_sets {
725 let mut groups = Vec::new();
726 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
727 diagnostic_groups.extend(
728 groups
729 .into_iter()
730 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
731 );
732 }
733
734 // sort by proximity to cursor
735 diagnostic_groups.sort_by_key(|group| {
736 let range = &group.entries[group.primary_ix].range;
737 if range.start >= cursor_offset {
738 range.start - cursor_offset
739 } else if cursor_offset >= range.end {
740 cursor_offset - range.end
741 } else {
742 (cursor_offset - range.start).min(range.end - cursor_offset)
743 }
744 });
745
746 let mut results = Vec::new();
747 let mut diagnostic_groups_truncated = false;
748 let mut diagnostics_byte_count = 0;
749 for group in diagnostic_groups {
750 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
751 diagnostics_byte_count += raw_value.get().len();
752 if diagnostics_byte_count > max_diagnostics_bytes {
753 diagnostic_groups_truncated = true;
754 break;
755 }
756 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
757 }
758
759 (results, diagnostic_groups_truncated)
760 }
761
762 // TODO: Dedupe with similar code in request_prediction?
763 pub fn cloud_request_for_zeta_cli(
764 &mut self,
765 project: &Entity<Project>,
766 buffer: &Entity<Buffer>,
767 position: language::Anchor,
768 cx: &mut Context<Self>,
769 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
770 let project_state = self.projects.get(&project.entity_id());
771
772 let index_state = project_state.map(|state| {
773 state
774 .syntax_index
775 .read_with(cx, |index, _cx| index.state().clone())
776 });
777 let options = self.options.clone();
778 let snapshot = buffer.read(cx).snapshot();
779 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
780 return Task::ready(Err(anyhow!("No file path for excerpt")));
781 };
782 let worktree_snapshots = project
783 .read(cx)
784 .worktrees(cx)
785 .map(|worktree| worktree.read(cx).snapshot())
786 .collect::<Vec<_>>();
787
788 cx.background_spawn(async move {
789 let index_state = if let Some(index_state) = index_state {
790 Some(index_state.lock_owned().await)
791 } else {
792 None
793 };
794
795 let cursor_point = position.to_point(&snapshot);
796
797 let debug_info = true;
798 EditPredictionContext::gather_context(
799 cursor_point,
800 &snapshot,
801 &options.excerpt,
802 index_state.as_deref(),
803 )
804 .context("Failed to select excerpt")
805 .map(|context| {
806 make_cloud_request(
807 excerpt_path.into(),
808 context,
809 // TODO pass everything
810 Vec::new(),
811 false,
812 Vec::new(),
813 false,
814 None,
815 debug_info,
816 &worktree_snapshots,
817 index_state.as_deref(),
818 Some(options.max_prompt_bytes),
819 options.prompt_format,
820 )
821 })
822 })
823 }
824
825 pub fn wait_for_initial_indexing(
826 &mut self,
827 project: &Entity<Project>,
828 cx: &mut App,
829 ) -> Task<Result<()>> {
830 let zeta_project = self.get_or_init_zeta_project(project, cx);
831 zeta_project
832 .syntax_index
833 .read(cx)
834 .wait_for_initial_file_indexing(cx)
835 }
836}
837
/// Returned by `perform_request` when the server's minimum-required-version header is
/// newer than the running Zed version.
///
/// `request_prediction` checks for this error type to set `update_required` and show an
/// "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
845
846fn make_cloud_request(
847 excerpt_path: Arc<Path>,
848 context: EditPredictionContext,
849 events: Vec<predict_edits_v3::Event>,
850 can_collect_data: bool,
851 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
852 diagnostic_groups_truncated: bool,
853 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
854 debug_info: bool,
855 worktrees: &Vec<worktree::Snapshot>,
856 index_state: Option<&SyntaxIndexState>,
857 prompt_max_bytes: Option<usize>,
858 prompt_format: PromptFormat,
859) -> predict_edits_v3::PredictEditsRequest {
860 let mut signatures = Vec::new();
861 let mut declaration_to_signature_index = HashMap::default();
862 let mut referenced_declarations = Vec::new();
863
864 for snippet in context.declarations {
865 let project_entry_id = snippet.declaration.project_entry_id();
866 let Some(path) = worktrees.iter().find_map(|worktree| {
867 worktree.entry_for_id(project_entry_id).map(|entry| {
868 let mut full_path = RelPathBuf::new();
869 full_path.push(worktree.root_name());
870 full_path.push(&entry.path);
871 full_path
872 })
873 }) else {
874 continue;
875 };
876
877 let parent_index = index_state.and_then(|index_state| {
878 snippet.declaration.parent().and_then(|parent| {
879 add_signature(
880 parent,
881 &mut declaration_to_signature_index,
882 &mut signatures,
883 index_state,
884 )
885 })
886 });
887
888 let (text, text_is_truncated) = snippet.declaration.item_text();
889 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
890 path: path.as_std_path().into(),
891 text: text.into(),
892 range: snippet.declaration.item_range(),
893 text_is_truncated,
894 signature_range: snippet.declaration.signature_range_in_item_text(),
895 parent_index,
896 score_components: snippet.score_components,
897 signature_score: snippet.scores.signature,
898 declaration_score: snippet.scores.declaration,
899 });
900 }
901
902 let excerpt_parent = index_state.and_then(|index_state| {
903 context
904 .excerpt
905 .parent_declarations
906 .last()
907 .and_then(|(parent, _)| {
908 add_signature(
909 *parent,
910 &mut declaration_to_signature_index,
911 &mut signatures,
912 index_state,
913 )
914 })
915 });
916
917 predict_edits_v3::PredictEditsRequest {
918 excerpt_path,
919 excerpt: context.excerpt_text.body,
920 excerpt_range: context.excerpt.range,
921 cursor_offset: context.cursor_offset_in_excerpt,
922 referenced_declarations,
923 signatures,
924 excerpt_parent,
925 events,
926 can_collect_data,
927 diagnostic_groups,
928 diagnostic_groups_truncated,
929 git_info,
930 debug_info,
931 prompt_max_bytes,
932 prompt_format,
933 }
934}
935
936fn add_signature(
937 declaration_id: DeclarationId,
938 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
939 signatures: &mut Vec<Signature>,
940 index: &SyntaxIndexState,
941) -> Option<usize> {
942 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
943 return Some(*signature_index);
944 }
945 let Some(parent_declaration) = index.declaration(declaration_id) else {
946 log::error!("bug: missing parent declaration");
947 return None;
948 };
949 let parent_index = parent_declaration.parent().and_then(|parent| {
950 add_signature(parent, declaration_to_signature_index, signatures, index)
951 });
952 let (text, text_is_truncated) = parent_declaration.signature_text();
953 let signature_index = signatures.len();
954 signatures.push(Signature {
955 text: text.into(),
956 text_is_truncated,
957 parent_index,
958 range: parent_declaration.signature_range(),
959 });
960 declaration_to_signature_index.insert(declaration_id, signature_index);
961 Some(signature_index)
962}
963
964#[cfg(test)]
965mod tests {
966 use std::{
967 path::{Path, PathBuf},
968 sync::Arc,
969 };
970
971 use client::UserStore;
972 use clock::FakeSystemClock;
973 use cloud_llm_client::predict_edits_v3;
974 use futures::{
975 AsyncReadExt, StreamExt,
976 channel::{mpsc, oneshot},
977 };
978 use gpui::{
979 Entity, TestAppContext,
980 http_client::{FakeHttpClient, Response},
981 prelude::*,
982 };
983 use indoc::indoc;
984 use language::{LanguageServerId, OffsetRangeExt as _};
985 use pretty_assertions::{assert_eq, assert_matches};
986 use project::{FakeFs, Project};
987 use serde_json::json;
988 use settings::SettingsStore;
989 use util::path;
990 use uuid::Uuid;
991
992 use crate::{BufferEditPrediction, Zeta};
993
994 #[gpui::test]
995 async fn test_current_state(cx: &mut TestAppContext) {
996 let (zeta, mut req_rx) = init_test(cx);
997 let fs = FakeFs::new(cx.executor());
998 fs.insert_tree(
999 "/root",
1000 json!({
1001 "1.txt": "Hello!\nHow\nBye",
1002 "2.txt": "Hola!\nComo\nAdios"
1003 }),
1004 )
1005 .await;
1006 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1007
1008 zeta.update(cx, |zeta, cx| {
1009 zeta.register_project(&project, cx);
1010 });
1011
1012 let buffer1 = project
1013 .update(cx, |project, cx| {
1014 let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1015 project.open_buffer(path, cx)
1016 })
1017 .await
1018 .unwrap();
1019 let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1020 let position = snapshot1.anchor_before(language::Point::new(1, 3));
1021
1022 // Prediction for current file
1023
1024 let prediction_task = zeta.update(cx, |zeta, cx| {
1025 zeta.refresh_prediction(&project, &buffer1, position, cx)
1026 });
1027 let (_request, respond_tx) = req_rx.next().await.unwrap();
1028 respond_tx
1029 .send(predict_edits_v3::PredictEditsResponse {
1030 request_id: Uuid::new_v4(),
1031 edits: vec![predict_edits_v3::Edit {
1032 path: Path::new(path!("root/1.txt")).into(),
1033 range: 0..snapshot1.len(),
1034 content: "Hello!\nHow are you?\nBye".into(),
1035 }],
1036 debug_info: None,
1037 })
1038 .unwrap();
1039 prediction_task.await.unwrap();
1040
1041 zeta.read_with(cx, |zeta, cx| {
1042 let prediction = zeta
1043 .current_prediction_for_buffer(&buffer1, &project, cx)
1044 .unwrap();
1045 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1046 });
1047
1048 // Prediction for another file
1049
1050 let prediction_task = zeta.update(cx, |zeta, cx| {
1051 zeta.refresh_prediction(&project, &buffer1, position, cx)
1052 });
1053 let (_request, respond_tx) = req_rx.next().await.unwrap();
1054 respond_tx
1055 .send(predict_edits_v3::PredictEditsResponse {
1056 request_id: Uuid::new_v4(),
1057 edits: vec![predict_edits_v3::Edit {
1058 path: Path::new(path!("root/2.txt")).into(),
1059 range: 0..snapshot1.len(),
1060 content: "Hola!\nComo estas?\nAdios".into(),
1061 }],
1062 debug_info: None,
1063 })
1064 .unwrap();
1065 prediction_task.await.unwrap();
1066
1067 zeta.read_with(cx, |zeta, cx| {
1068 let prediction = zeta
1069 .current_prediction_for_buffer(&buffer1, &project, cx)
1070 .unwrap();
1071 assert_matches!(
1072 prediction,
1073 BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1074 );
1075 });
1076
1077 let buffer2 = project
1078 .update(cx, |project, cx| {
1079 let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1080 project.open_buffer(path, cx)
1081 })
1082 .await
1083 .unwrap();
1084
1085 zeta.read_with(cx, |zeta, cx| {
1086 let prediction = zeta
1087 .current_prediction_for_buffer(&buffer2, &project, cx)
1088 .unwrap();
1089 assert_matches!(prediction, BufferEditPrediction::Local { .. });
1090 });
1091 }
1092
1093 #[gpui::test]
1094 async fn test_simple_request(cx: &mut TestAppContext) {
1095 let (zeta, mut req_rx) = init_test(cx);
1096 let fs = FakeFs::new(cx.executor());
1097 fs.insert_tree(
1098 "/root",
1099 json!({
1100 "foo.md": "Hello!\nHow\nBye"
1101 }),
1102 )
1103 .await;
1104 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1105
1106 let buffer = project
1107 .update(cx, |project, cx| {
1108 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1109 project.open_buffer(path, cx)
1110 })
1111 .await
1112 .unwrap();
1113 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1114 let position = snapshot.anchor_before(language::Point::new(1, 3));
1115
1116 let prediction_task = zeta.update(cx, |zeta, cx| {
1117 zeta.request_prediction(&project, &buffer, position, cx)
1118 });
1119
1120 let (request, respond_tx) = req_rx.next().await.unwrap();
1121 assert_eq!(
1122 request.excerpt_path.as_ref(),
1123 Path::new(path!("root/foo.md"))
1124 );
1125 assert_eq!(request.cursor_offset, 10);
1126
1127 respond_tx
1128 .send(predict_edits_v3::PredictEditsResponse {
1129 request_id: Uuid::new_v4(),
1130 edits: vec![predict_edits_v3::Edit {
1131 path: Path::new(path!("root/foo.md")).into(),
1132 range: 0..snapshot.len(),
1133 content: "Hello!\nHow are you?\nBye".into(),
1134 }],
1135 debug_info: None,
1136 })
1137 .unwrap();
1138
1139 let prediction = prediction_task.await.unwrap().unwrap();
1140
1141 assert_eq!(prediction.edits.len(), 1);
1142 assert_eq!(
1143 prediction.edits[0].0.to_point(&snapshot).start,
1144 language::Point::new(1, 3)
1145 );
1146 assert_eq!(prediction.edits[0].1, " are you?");
1147 }
1148
1149 #[gpui::test]
1150 async fn test_request_events(cx: &mut TestAppContext) {
1151 let (zeta, mut req_rx) = init_test(cx);
1152 let fs = FakeFs::new(cx.executor());
1153 fs.insert_tree(
1154 "/root",
1155 json!({
1156 "foo.md": "Hello!\n\nBye"
1157 }),
1158 )
1159 .await;
1160 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1161
1162 let buffer = project
1163 .update(cx, |project, cx| {
1164 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1165 project.open_buffer(path, cx)
1166 })
1167 .await
1168 .unwrap();
1169
1170 zeta.update(cx, |zeta, cx| {
1171 zeta.register_buffer(&buffer, &project, cx);
1172 });
1173
1174 buffer.update(cx, |buffer, cx| {
1175 buffer.edit(vec![(7..7, "How")], None, cx);
1176 });
1177
1178 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1179 let position = snapshot.anchor_before(language::Point::new(1, 3));
1180
1181 let prediction_task = zeta.update(cx, |zeta, cx| {
1182 zeta.request_prediction(&project, &buffer, position, cx)
1183 });
1184
1185 let (request, respond_tx) = req_rx.next().await.unwrap();
1186
1187 assert_eq!(request.events.len(), 1);
1188 assert_eq!(
1189 request.events[0],
1190 predict_edits_v3::Event::BufferChange {
1191 path: Some(PathBuf::from(path!("root/foo.md"))),
1192 old_path: None,
1193 diff: indoc! {"
1194 @@ -1,3 +1,3 @@
1195 Hello!
1196 -
1197 +How
1198 Bye
1199 "}
1200 .to_string(),
1201 predicted: false
1202 }
1203 );
1204
1205 respond_tx
1206 .send(predict_edits_v3::PredictEditsResponse {
1207 request_id: Uuid::new_v4(),
1208 edits: vec![predict_edits_v3::Edit {
1209 path: Path::new(path!("root/foo.md")).into(),
1210 range: 0..snapshot.len(),
1211 content: "Hello!\nHow are you?\nBye".into(),
1212 }],
1213 debug_info: None,
1214 })
1215 .unwrap();
1216
1217 let prediction = prediction_task.await.unwrap().unwrap();
1218
1219 assert_eq!(prediction.edits.len(), 1);
1220 assert_eq!(
1221 prediction.edits[0].0.to_point(&snapshot).start,
1222 language::Point::new(1, 3)
1223 );
1224 assert_eq!(prediction.edits[0].1, " are you?");
1225 }
1226
1227 #[gpui::test]
1228 async fn test_request_diagnostics(cx: &mut TestAppContext) {
1229 let (zeta, mut req_rx) = init_test(cx);
1230 let fs = FakeFs::new(cx.executor());
1231 fs.insert_tree(
1232 "/root",
1233 json!({
1234 "foo.md": "Hello!\nBye"
1235 }),
1236 )
1237 .await;
1238 let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1239
1240 let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1241 let diagnostic = lsp::Diagnostic {
1242 range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1243 severity: Some(lsp::DiagnosticSeverity::ERROR),
1244 message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1245 ..Default::default()
1246 };
1247
1248 project.update(cx, |project, cx| {
1249 project.lsp_store().update(cx, |lsp_store, cx| {
1250 // Create some diagnostics
1251 lsp_store
1252 .update_diagnostics(
1253 LanguageServerId(0),
1254 lsp::PublishDiagnosticsParams {
1255 uri: path_to_buffer_uri.clone(),
1256 diagnostics: vec![diagnostic],
1257 version: None,
1258 },
1259 None,
1260 language::DiagnosticSourceKind::Pushed,
1261 &[],
1262 cx,
1263 )
1264 .unwrap();
1265 });
1266 });
1267
1268 let buffer = project
1269 .update(cx, |project, cx| {
1270 let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1271 project.open_buffer(path, cx)
1272 })
1273 .await
1274 .unwrap();
1275
1276 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1277 let position = snapshot.anchor_before(language::Point::new(0, 0));
1278
1279 let _prediction_task = zeta.update(cx, |zeta, cx| {
1280 zeta.request_prediction(&project, &buffer, position, cx)
1281 });
1282
1283 let (request, _respond_tx) = req_rx.next().await.unwrap();
1284
1285 assert_eq!(request.diagnostic_groups.len(), 1);
1286 let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1287 .unwrap();
1288 // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1289 assert_eq!(
1290 value,
1291 json!({
1292 "entries": [{
1293 "range": {
1294 "start": 8,
1295 "end": 10
1296 },
1297 "diagnostic": {
1298 "source": null,
1299 "code": null,
1300 "code_description": null,
1301 "severity": 1,
1302 "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1303 "markdown": null,
1304 "group_id": 0,
1305 "is_primary": true,
1306 "is_disk_based": false,
1307 "is_unnecessary": false,
1308 "source_kind": "Pushed",
1309 "data": null,
1310 "underline": true
1311 }
1312 }],
1313 "primary_ix": 0
1314 })
1315 );
1316 }
1317
1318 fn init_test(
1319 cx: &mut TestAppContext,
1320 ) -> (
1321 Entity<Zeta>,
1322 mpsc::UnboundedReceiver<(
1323 predict_edits_v3::PredictEditsRequest,
1324 oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1325 )>,
1326 ) {
1327 cx.update(move |cx| {
1328 let settings_store = SettingsStore::test(cx);
1329 cx.set_global(settings_store);
1330 language::init(cx);
1331 Project::init_settings(cx);
1332
1333 let (req_tx, req_rx) = mpsc::unbounded();
1334
1335 let http_client = FakeHttpClient::create({
1336 move |req| {
1337 let uri = req.uri().path().to_string();
1338 let mut body = req.into_body();
1339 let req_tx = req_tx.clone();
1340 async move {
1341 let resp = match uri.as_str() {
1342 "/client/llm_tokens" => serde_json::to_string(&json!({
1343 "token": "test"
1344 }))
1345 .unwrap(),
1346 "/predict_edits/v3" => {
1347 let mut buf = Vec::new();
1348 body.read_to_end(&mut buf).await.ok();
1349 let req = serde_json::from_slice(&buf).unwrap();
1350
1351 let (res_tx, res_rx) = oneshot::channel();
1352 req_tx.unbounded_send((req, res_tx)).unwrap();
1353 serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1354 }
1355 _ => {
1356 panic!("Unexpected path: {}", uri)
1357 }
1358 };
1359
1360 Ok(Response::builder().body(resp.into()).unwrap())
1361 }
1362 }
1363 });
1364
1365 let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1366 client.cloud_client().set_credentials(1, "test".into());
1367
1368 language_model::init(client.clone(), cx);
1369
1370 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1371 let zeta = Zeta::global(&client, &user_store, cx);
1372
1373 (zeta, req_rx)
1374 })
1375 }
1376}