1use anyhow::{Context as _, Result, anyhow};
2use arrayvec::ArrayVec;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use edit_prediction::{DataCollectionState, Direction, EditPrediction, EditPredictionProvider};
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use gpui::http_client::Method;
15use gpui::{
16 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
17 prelude::*,
18};
19use language::BufferSnapshot;
20use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
21use language_model::{LlmApiToken, RefreshLlmTokenListener};
22use project::Project;
23use release_channel::AppVersion;
24use std::collections::HashMap;
25use std::path::PathBuf;
26use std::str::FromStr as _;
27use std::time::{Duration, Instant};
28use std::{ops::Range, sync::Arc};
29use thiserror::Error;
30use util::ResultExt as _;
31use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
32
/// Global wrapper that lets a single [`Zeta`] entity be shared app-wide via
/// `cx.set_global` / `cx.try_global` (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

/// App-wide state for zeta2 edit predictions: the cloud client, the user store
/// (used to record prediction usage), the LLM API token, and per-project
/// syntax indices.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    // Kept alive so `llm_token` is refreshed whenever the global
    // `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    // Keyed by the project entity's id; populated by `register_project`.
    projects: HashMap<EntityId, RegisteredProject>,
    // Size limits used when gathering the excerpt around the cursor.
    excerpt_options: EditPredictionExcerptOptions,
    // Set to `true` once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
}

/// Per-project state tracked by [`Zeta`].
struct RegisteredProject {
    syntax_index: Entity<SyntaxIndex>,
}
51
52impl Zeta {
53 pub fn global(
54 client: &Arc<Client>,
55 user_store: &Entity<UserStore>,
56 cx: &mut App,
57 ) -> Entity<Self> {
58 cx.try_global::<ZetaGlobal>()
59 .map(|global| global.0.clone())
60 .unwrap_or_else(|| {
61 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
62 cx.set_global(ZetaGlobal(zeta.clone()));
63 zeta
64 })
65 }
66
67 fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
68 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
69
70 Self {
71 projects: HashMap::new(),
72 client,
73 user_store,
74 excerpt_options: EditPredictionExcerptOptions {
75 max_bytes: 512,
76 min_bytes: 128,
77 target_before_cursor_over_total_bytes: 0.5,
78 },
79 llm_token: LlmApiToken::default(),
80 _llm_token_subscription: cx.subscribe(
81 &refresh_llm_token_listener,
82 |this, _listener, _event, cx| {
83 let client = this.client.clone();
84 let llm_token = this.llm_token.clone();
85 cx.spawn(async move |_this, _cx| {
86 llm_token.refresh(&client).await?;
87 anyhow::Ok(())
88 })
89 .detach_and_log_err(cx);
90 },
91 ),
92 update_required: false,
93 }
94 }
95
96 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
97 self.user_store.read(cx).edit_prediction_usage()
98 }
99
100 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
101 self.projects
102 .entry(project.entity_id())
103 .or_insert_with(|| RegisteredProject {
104 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
105 });
106 }
107
108 pub fn request_prediction(
109 &mut self,
110 project: &Entity<Project>,
111 buffer: &Entity<Buffer>,
112 position: language::Anchor,
113 cx: &mut Context<Self>,
114 ) -> Task<Result<Option<EditPrediction>>> {
115 let project_state = self.projects.get(&project.entity_id());
116
117 let index_state = project_state.map(|state| {
118 state
119 .syntax_index
120 .read_with(cx, |index, _cx| index.state().clone())
121 });
122 let excerpt_options = self.excerpt_options.clone();
123 let snapshot = buffer.read(cx).snapshot();
124 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
125 return Task::ready(Err(anyhow!("No file path for excerpt")));
126 };
127 let client = self.client.clone();
128 let llm_token = self.llm_token.clone();
129 let app_version = AppVersion::global(cx);
130 let worktree_snapshots = project
131 .read(cx)
132 .worktrees(cx)
133 .map(|worktree| worktree.read(cx).snapshot())
134 .collect::<Vec<_>>();
135
136 let request_task = cx.background_spawn({
137 let snapshot = snapshot.clone();
138 async move {
139 let index_state = if let Some(index_state) = index_state {
140 Some(index_state.lock_owned().await)
141 } else {
142 None
143 };
144
145 let cursor_point = position.to_point(&snapshot);
146
147 // TODO: make this only true if debug view is open
148 let debug_info = true;
149
150 let Some(request) = EditPredictionContext::gather_context(
151 cursor_point,
152 &snapshot,
153 &excerpt_options,
154 index_state.as_deref(),
155 )
156 .map(|context| {
157 make_cloud_request(
158 excerpt_path.clone(),
159 context,
160 // TODO pass everything
161 Vec::new(),
162 false,
163 Vec::new(),
164 None,
165 debug_info,
166 &worktree_snapshots,
167 index_state.as_deref(),
168 )
169 }) else {
170 return Ok(None);
171 };
172
173 anyhow::Ok(Some(
174 Self::perform_request(client, llm_token, app_version, request).await?,
175 ))
176 }
177 });
178
179 let buffer = buffer.clone();
180
181 cx.spawn(async move |this, cx| {
182 match request_task.await {
183 Ok(Some((response, usage))) => {
184 log::debug!("predicted edits: {:?}", &response.edits);
185
186 if let Some(usage) = usage {
187 this.update(cx, |this, cx| {
188 this.user_store.update(cx, |user_store, cx| {
189 user_store.update_edit_prediction_usage(usage, cx);
190 });
191 })
192 .ok();
193 }
194
195 // TODO telemetry: duration, etc
196
197 let edits = response
198 .edits
199 .into_iter()
200 .map(|edit| {
201 // TODO edits to different files
202 (
203 snapshot.anchor_before(edit.range.start)
204 ..snapshot.anchor_before(edit.range.end),
205 edit.content,
206 )
207 })
208 .collect::<Vec<_>>()
209 .into();
210
211 let Some((edits, edit_preview_task)) = buffer.read_with(cx, |buffer, cx| {
212 let new_snapshot = buffer.snapshot();
213 let edits: Arc<[_]> = interpolate(&snapshot, &new_snapshot, edits)?.into();
214 Some((edits.clone().to_vec(), buffer.preview_edits(edits, cx)))
215 })?
216 else {
217 return Ok(None);
218 };
219
220 Ok(Some(EditPrediction {
221 // todo!
222 id: None,
223 edits,
224 edit_preview: Some(edit_preview_task.await),
225 }))
226 }
227 Ok(None) => Ok(None),
228 Err(err) => {
229 if err.is::<ZedUpdateRequiredError>() {
230 cx.update(|cx| {
231 this.update(cx, |this, _cx| {
232 this.update_required = true;
233 })
234 .ok();
235
236 let error_message: SharedString = err.to_string().into();
237 show_app_notification(
238 NotificationId::unique::<ZedUpdateRequiredError>(),
239 cx,
240 move |cx| {
241 cx.new(|cx| {
242 ErrorMessagePrompt::new(error_message.clone(), cx)
243 .with_link_button(
244 "Update Zed",
245 "https://zed.dev/releases",
246 )
247 })
248 },
249 );
250 })
251 .ok();
252 }
253
254 Err(err)
255 }
256 }
257 })
258 }
259
260 async fn perform_request(
261 client: Arc<Client>,
262 llm_token: LlmApiToken,
263 app_version: SemanticVersion,
264 request: predict_edits_v3::PredictEditsRequest,
265 ) -> Result<(
266 predict_edits_v3::PredictEditsResponse,
267 Option<EditPredictionUsage>,
268 )> {
269 let http_client = client.http_client();
270 let mut token = llm_token.acquire(&client).await?;
271 let mut did_retry = false;
272
273 loop {
274 let request_builder = http_client::Request::builder().method(Method::POST);
275 let request_builder =
276 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
277 request_builder.uri(predict_edits_url)
278 } else {
279 request_builder.uri(
280 http_client
281 .build_zed_llm_url("/predict_edits/v3", &[])?
282 .as_ref(),
283 )
284 };
285 let request = request_builder
286 .header("Content-Type", "application/json")
287 .header("Authorization", format!("Bearer {}", token))
288 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
289 .body(serde_json::to_string(&request)?.into())?;
290
291 let mut response = http_client.send(request).await?;
292
293 if let Some(minimum_required_version) = response
294 .headers()
295 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
296 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
297 {
298 anyhow::ensure!(
299 app_version >= minimum_required_version,
300 ZedUpdateRequiredError {
301 minimum_version: minimum_required_version
302 }
303 );
304 }
305
306 if response.status().is_success() {
307 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
308
309 let mut body = Vec::new();
310 response.body_mut().read_to_end(&mut body).await?;
311 return Ok((serde_json::from_slice(&body)?, usage));
312 } else if !did_retry
313 && response
314 .headers()
315 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
316 .is_some()
317 {
318 did_retry = true;
319 token = llm_token.refresh(&client).await?;
320 } else {
321 let mut body = String::new();
322 response.body_mut().read_to_string(&mut body).await?;
323 anyhow::bail!(
324 "error predicting edits.\nStatus: {:?}\nBody: {}",
325 response.status(),
326 body
327 );
328 }
329 }
330 }
331}
332
/// Error produced when the server's minimum required Zed version is newer than
/// the running version. `Zeta::request_prediction` reacts to it by prompting the
/// user to update.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // The minimum version reported by the server.
    minimum_version: SemanticVersion,
}
340
/// [`EditPredictionProvider`] implementation backed by the app-wide [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    // The most recent prediction received, together with the buffer it targets.
    current_prediction: Option<CurrentEditPrediction>,
    // Id handed to the next `refresh` call; incremented on every call.
    next_pending_prediction_id: usize,
    // In-flight requests, oldest first; capacity is capped at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    // Updated when a request actually starts. New requests wait until
    // `THROTTLE_TIMEOUT` has elapsed since this instant.
    last_request_timestamp: Instant,
}
348
349impl ZetaEditPredictionProvider {
350 pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
351
352 pub fn new(
353 project: Option<&Entity<Project>>,
354 client: &Arc<Client>,
355 user_store: &Entity<UserStore>,
356 cx: &mut App,
357 ) -> Self {
358 let zeta = Zeta::global(client, user_store, cx);
359 if let Some(project) = project {
360 zeta.update(cx, |zeta, cx| {
361 zeta.register_project(project, cx);
362 });
363 }
364
365 Self {
366 zeta,
367 current_prediction: None,
368 next_pending_prediction_id: 0,
369 pending_predictions: ArrayVec::new(),
370 last_request_timestamp: Instant::now(),
371 }
372 }
373}
374
/// A received prediction together with the id of the buffer entity it targets.
#[derive(Clone)]
struct CurrentEditPrediction {
    buffer_id: EntityId,
    prediction: EditPrediction,
}
380
impl CurrentEditPrediction {
    /// Decides whether this newly received prediction should replace
    /// `_old_completion` as the current one.
    ///
    /// Currently a placeholder that always returns `true`. The commented-out
    /// sketch below is the intended logic.
    // NOTE(review): the sketch accesses a `.completion` field, but this struct's
    // field is named `prediction`; it will need adapting before being enabled.
    fn should_replace_prediction(
        &self,
        _old_completion: &Self,
        _snapshot: &BufferSnapshot,
    ) -> bool {
        true
        // TODO
        // if self.buffer_id != old_completion.buffer_id {
        //     return true;
        // }

        // let Some(old_edits) = old_completion.completion.interpolate(snapshot) else {
        //     return true;
        // };
        // let Some(new_edits) = self.completion.interpolate(snapshot) else {
        //     return false;
        // };

        // if old_edits.len() == 1 && new_edits.len() == 1 {
        //     let (old_range, old_text) = &old_edits[0];
        //     let (new_range, new_text) = &new_edits[0];
        //     new_range == old_range && new_text.starts_with(old_text)
        // } else {
        //     true
        // }
    }
}
409
/// An in-flight prediction request started by `ZetaEditPredictionProvider::refresh`.
struct PendingPrediction {
    // Value of `next_pending_prediction_id` when this request was started.
    id: usize,
    // Owned here so that removing this entry drops the task.
    // NOTE(review): presumably dropping a gpui `Task` cancels it — confirm.
    _task: Task<()>,
}
414
415impl EditPredictionProvider for ZetaEditPredictionProvider {
416 fn name() -> &'static str {
417 // TODO [zeta2]
418 "zed-predict2"
419 }
420
421 fn display_name() -> &'static str {
422 "Zed's Edit Predictions 2"
423 }
424
425 fn show_completions_in_menu() -> bool {
426 true
427 }
428
429 fn show_tab_accept_marker() -> bool {
430 true
431 }
432
433 fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
434 // TODO [zeta2]
435 DataCollectionState::Unsupported
436 }
437
438 fn toggle_data_collection(&mut self, _cx: &mut App) {
439 // TODO [zeta2]
440 }
441
442 fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
443 self.zeta.read(cx).usage(cx)
444 }
445
446 fn is_enabled(
447 &self,
448 _buffer: &Entity<language::Buffer>,
449 _cursor_position: language::Anchor,
450 _cx: &App,
451 ) -> bool {
452 true
453 }
454
455 fn is_refreshing(&self) -> bool {
456 !self.pending_predictions.is_empty()
457 }
458
459 fn refresh(
460 &mut self,
461 project: Option<Entity<project::Project>>,
462 buffer: Entity<language::Buffer>,
463 cursor_position: language::Anchor,
464 _debounce: bool,
465 cx: &mut Context<Self>,
466 ) {
467 let Some(project) = project else {
468 return;
469 };
470
471 // TODO [zeta2] check account
472 // if self
473 // .zeta
474 // .read(cx)
475 // .user_store
476 // .read_with(cx, |user_store, _cx| {
477 // user_store.account_too_young() || user_store.has_overdue_invoices()
478 // })
479 // {
480 // return;
481 // }
482
483 // TODO [zeta2] try to interpolate current request
484
485 let pending_prediction_id = self.next_pending_prediction_id;
486 self.next_pending_prediction_id += 1;
487 let last_request_timestamp = self.last_request_timestamp;
488
489 let task = cx.spawn(async move |this, cx| {
490 if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
491 .checked_duration_since(Instant::now())
492 {
493 cx.background_executor().timer(timeout).await;
494 }
495
496 let prediction_request = this.update(cx, |this, cx| {
497 this.last_request_timestamp = Instant::now();
498 this.zeta.update(cx, |zeta, cx| {
499 zeta.request_prediction(&project, &buffer, cursor_position, cx)
500 })
501 });
502
503 let prediction = match prediction_request {
504 Ok(prediction_request) => {
505 let prediction_request = prediction_request.await;
506 prediction_request.map(|c| {
507 c.map(|prediction| CurrentEditPrediction {
508 buffer_id: buffer.entity_id(),
509 prediction,
510 })
511 })
512 }
513 Err(error) => Err(error),
514 };
515
516 this.update(cx, |this, cx| {
517 if this.pending_predictions[0].id == pending_prediction_id {
518 this.pending_predictions.remove(0);
519 } else {
520 this.pending_predictions.clear();
521 }
522
523 let Some(new_prediction) = prediction
524 .context("edit prediction failed")
525 .log_err()
526 .flatten()
527 else {
528 cx.notify();
529 return;
530 };
531
532 if let Some(old_prediction) = this.current_prediction.as_ref() {
533 let snapshot = buffer.read(cx).snapshot();
534 if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
535 this.current_prediction = Some(new_prediction);
536 }
537 } else {
538 this.current_prediction = Some(new_prediction);
539 }
540
541 cx.notify();
542 })
543 .ok();
544 });
545
546 // We always maintain at most two pending predictions. When we already
547 // have two, we replace the newest one.
548 if self.pending_predictions.len() <= 1 {
549 self.pending_predictions.push(PendingPrediction {
550 id: pending_prediction_id,
551 _task: task,
552 });
553 } else if self.pending_predictions.len() == 2 {
554 self.pending_predictions.pop();
555 self.pending_predictions.push(PendingPrediction {
556 id: pending_prediction_id,
557 _task: task,
558 });
559 }
560
561 cx.notify();
562 }
563
564 fn cycle(
565 &mut self,
566 _buffer: Entity<language::Buffer>,
567 _cursor_position: language::Anchor,
568 _direction: Direction,
569 _cx: &mut Context<Self>,
570 ) {
571 }
572
573 fn accept(&mut self, _cx: &mut Context<Self>) {
574 // TODO [zeta2] report accept
575 self.current_prediction.take();
576 self.pending_predictions.clear();
577 }
578
579 fn discard(&mut self, _cx: &mut Context<Self>) {
580 self.pending_predictions.clear();
581 self.current_prediction.take();
582 }
583
584 fn suggest(
585 &mut self,
586 buffer: &Entity<language::Buffer>,
587 _cursor_position: language::Anchor,
588 _cx: &mut Context<Self>,
589 ) -> Option<EditPrediction> {
590 let current_prediction = self.current_prediction.take()?;
591
592 if current_prediction.buffer_id != buffer.entity_id() {
593 return None;
594 }
595
596 // TODO [zeta2] interpolate
597
598 Some(current_prediction.prediction)
599 }
600}
601
602fn make_cloud_request(
603 excerpt_path: PathBuf,
604 context: EditPredictionContext,
605 events: Vec<predict_edits_v3::Event>,
606 can_collect_data: bool,
607 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
608 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
609 debug_info: bool,
610 worktrees: &Vec<worktree::Snapshot>,
611 index_state: Option<&SyntaxIndexState>,
612) -> predict_edits_v3::PredictEditsRequest {
613 let mut signatures = Vec::new();
614 let mut declaration_to_signature_index = HashMap::default();
615 let mut referenced_declarations = Vec::new();
616
617 for snippet in context.snippets {
618 let project_entry_id = snippet.declaration.project_entry_id();
619 // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
620 // Note that currently full_path is currently being used for excerpt_path.
621 let Some(path) = worktrees.iter().find_map(|worktree| {
622 let abs_path = worktree.abs_path();
623 worktree
624 .entry_for_id(project_entry_id)
625 .map(|e| abs_path.join(&e.path))
626 }) else {
627 continue;
628 };
629
630 let parent_index = index_state.and_then(|index_state| {
631 snippet.declaration.parent().and_then(|parent| {
632 add_signature(
633 parent,
634 &mut declaration_to_signature_index,
635 &mut signatures,
636 index_state,
637 )
638 })
639 });
640
641 let (text, text_is_truncated) = snippet.declaration.item_text();
642 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
643 path,
644 text: text.into(),
645 range: snippet.declaration.item_range(),
646 text_is_truncated,
647 signature_range: snippet.declaration.signature_range_in_item_text(),
648 parent_index,
649 score_components: snippet.score_components,
650 signature_score: snippet.scores.signature,
651 declaration_score: snippet.scores.declaration,
652 });
653 }
654
655 let excerpt_parent = index_state.and_then(|index_state| {
656 context
657 .excerpt
658 .parent_declarations
659 .last()
660 .and_then(|(parent, _)| {
661 add_signature(
662 *parent,
663 &mut declaration_to_signature_index,
664 &mut signatures,
665 index_state,
666 )
667 })
668 });
669
670 predict_edits_v3::PredictEditsRequest {
671 excerpt_path,
672 excerpt: context.excerpt_text.body,
673 excerpt_range: context.excerpt.range,
674 cursor_offset: context.cursor_offset_in_excerpt,
675 referenced_declarations,
676 signatures,
677 excerpt_parent,
678 // todo!
679 events,
680 can_collect_data,
681 diagnostic_groups,
682 git_info,
683 debug_info,
684 }
685}
686
687fn add_signature(
688 declaration_id: DeclarationId,
689 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
690 signatures: &mut Vec<Signature>,
691 index: &SyntaxIndexState,
692) -> Option<usize> {
693 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
694 return Some(*signature_index);
695 }
696 let Some(parent_declaration) = index.declaration(declaration_id) else {
697 log::error!("bug: missing parent declaration");
698 return None;
699 };
700 let parent_index = parent_declaration.parent().and_then(|parent| {
701 add_signature(parent, declaration_to_signature_index, signatures, index)
702 });
703 let (text, text_is_truncated) = parent_declaration.signature_text();
704 let signature_index = signatures.len();
705 signatures.push(Signature {
706 text: text.into(),
707 text_is_truncated,
708 parent_index,
709 });
710 declaration_to_signature_index.insert(declaration_id, signature_index);
711 Some(signature_index)
712}
713
/// Rebases the model-predicted `current_edits` (anchored in `old_snapshot`)
/// onto `new_snapshot`, taking into account the edits the user made in between.
///
/// - Model edits whose range ends before a user edit starts are kept as-is.
/// - A user edit whose old range exactly equals the next model edit's range is
///   treated as the user typing part of the prediction, provided the inserted
///   text is a prefix of the model's new text. Only the remaining suffix (if
///   any) is kept, as an insertion right after the user's edit.
/// - Any other user edit invalidates the whole prediction and `None` is returned.
///
/// Model edits after the last user edit are kept unchanged. Returns `None` if
/// no edits remain.
// NOTE(review): the single forward pass assumes `current_edits` are sorted by
// position and non-overlapping — confirm with the caller.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Keep model edits that lie entirely before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is only compatible if it replaces exactly the next
        // model edit's range with a prefix of the model's text.
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Keep only the part of the prediction not yet typed.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // This user edit conflicts with the prediction (or no model edits remain).
        return None;
    }

    // Model edits after the last user edit are unaffected.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}