1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::PathBuf;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::some_or_debug_panic;
32use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
33
34mod prediction;
35mod provider;
36
37use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
38pub use provider::ZetaEditPredictionProvider;
39
/// Consecutive edits to the same buffer recorded within this window of each
/// other are merged into a single buffer-change event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of buffer-change events kept per project.
/// Once reached, the oldest half of the history is discarded before appending.
const MAX_EVENT_COUNT: usize = 16;
44
45pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
46 max_bytes: 512,
47 min_bytes: 128,
48 target_before_cursor_over_total_bytes: 0.5,
49};
50
51pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
52 excerpt: DEFAULT_EXCERPT_OPTIONS,
53 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
54 max_diagnostic_bytes: 2048,
55};
56
/// App-global handle to the shared [`Zeta`] entity; installed by [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

/// Edit-prediction engine: tracks per-project buffer edit history and requests
/// predictions from the cloud `predict_edits/v3` endpoint.
pub struct Zeta {
    client: Arc<Client>,
    // Holds edit-prediction usage; updated from response headers.
    user_store: Entity<UserStore>,
    // Bearer token for prediction requests. Refreshed when the global
    // `RefreshLlmTokenListener` emits, or when the server reports it expired.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    // Set by `debug_info`; prediction requests report debug info (or an error
    // string) through this channel.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
72
/// Tunable settings used when building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Excerpt selection settings passed to `EditPredictionContext::gather_context`.
    pub excerpt: EditPredictionExcerptOptions,
    /// Prompt size limit forwarded to the server as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
}

/// Debug information about one prediction request, delivered on the channel
/// returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time from the start of context gathering until the request payload was built.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-provided debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
89
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    /// Syntax index for the project; its state is passed to context gathering.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A buffer whose edits are being tracked.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; the next event diffs against it.
    snapshot: BufferSnapshot,
    /// The buffer-edit subscription and the release observer created when the
    /// buffer was registered.
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive
    /// changes to the same buffer may be coalesced into a single event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent (possibly coalesced) change was recorded.
        timestamp: Instant,
    },
}
109
110impl Zeta {
111 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
112 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
113 }
114
115 pub fn global(
116 client: &Arc<Client>,
117 user_store: &Entity<UserStore>,
118 cx: &mut App,
119 ) -> Entity<Self> {
120 cx.try_global::<ZetaGlobal>()
121 .map(|global| global.0.clone())
122 .unwrap_or_else(|| {
123 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
124 cx.set_global(ZetaGlobal(zeta.clone()));
125 zeta
126 })
127 }
128
129 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
130 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
131
132 Self {
133 projects: HashMap::new(),
134 client,
135 user_store,
136 options: DEFAULT_OPTIONS,
137 llm_token: LlmApiToken::default(),
138 _llm_token_subscription: cx.subscribe(
139 &refresh_llm_token_listener,
140 |this, _listener, _event, cx| {
141 let client = this.client.clone();
142 let llm_token = this.llm_token.clone();
143 cx.spawn(async move |_this, _cx| {
144 llm_token.refresh(&client).await?;
145 anyhow::Ok(())
146 })
147 .detach_and_log_err(cx);
148 },
149 ),
150 update_required: false,
151 debug_tx: None,
152 }
153 }
154
155 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
156 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
157 self.debug_tx = Some(debug_watch_tx);
158 debug_watch_rx
159 }
160
161 pub fn options(&self) -> &ZetaOptions {
162 &self.options
163 }
164
165 pub fn set_options(&mut self, options: ZetaOptions) {
166 self.options = options;
167 }
168
169 pub fn clear_history(&mut self) {
170 for zeta_project in self.projects.values_mut() {
171 zeta_project.events.clear();
172 }
173 }
174
175 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
176 self.user_store.read(cx).edit_prediction_usage()
177 }
178
179 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
180 self.get_or_init_zeta_project(project, cx);
181 }
182
183 pub fn register_buffer(
184 &mut self,
185 buffer: &Entity<Buffer>,
186 project: &Entity<Project>,
187 cx: &mut Context<Self>,
188 ) {
189 let zeta_project = self.get_or_init_zeta_project(project, cx);
190 Self::register_buffer_impl(zeta_project, buffer, project, cx);
191 }
192
193 fn get_or_init_zeta_project(
194 &mut self,
195 project: &Entity<Project>,
196 cx: &mut App,
197 ) -> &mut ZetaProject {
198 self.projects
199 .entry(project.entity_id())
200 .or_insert_with(|| ZetaProject {
201 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
202 events: VecDeque::new(),
203 registered_buffers: HashMap::new(),
204 })
205 }
206
207 fn register_buffer_impl<'a>(
208 zeta_project: &'a mut ZetaProject,
209 buffer: &Entity<Buffer>,
210 project: &Entity<Project>,
211 cx: &mut Context<Self>,
212 ) -> &'a mut RegisteredBuffer {
213 let buffer_id = buffer.entity_id();
214 match zeta_project.registered_buffers.entry(buffer_id) {
215 hash_map::Entry::Occupied(entry) => entry.into_mut(),
216 hash_map::Entry::Vacant(entry) => {
217 let snapshot = buffer.read(cx).snapshot();
218 let project_entity_id = project.entity_id();
219 entry.insert(RegisteredBuffer {
220 snapshot,
221 _subscriptions: [
222 cx.subscribe(buffer, {
223 let project = project.downgrade();
224 move |this, buffer, event, cx| {
225 if let language::BufferEvent::Edited = event
226 && let Some(project) = project.upgrade()
227 {
228 this.report_changes_for_buffer(&buffer, &project, cx);
229 }
230 }
231 }),
232 cx.observe_release(buffer, move |this, _buffer, _cx| {
233 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
234 else {
235 return;
236 };
237 zeta_project.registered_buffers.remove(&buffer_id);
238 }),
239 ],
240 })
241 }
242 }
243 }
244
245 fn report_changes_for_buffer(
246 &mut self,
247 buffer: &Entity<Buffer>,
248 project: &Entity<Project>,
249 cx: &mut Context<Self>,
250 ) -> BufferSnapshot {
251 let zeta_project = self.get_or_init_zeta_project(project, cx);
252 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
253
254 let new_snapshot = buffer.read(cx).snapshot();
255 if new_snapshot.version != registered_buffer.snapshot.version {
256 let old_snapshot =
257 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
258 Self::push_event(
259 zeta_project,
260 Event::BufferChange {
261 old_snapshot,
262 new_snapshot: new_snapshot.clone(),
263 timestamp: Instant::now(),
264 },
265 );
266 }
267
268 new_snapshot
269 }
270
271 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
272 let events = &mut zeta_project.events;
273
274 if let Some(Event::BufferChange {
275 new_snapshot: last_new_snapshot,
276 timestamp: last_timestamp,
277 ..
278 }) = events.back_mut()
279 {
280 // Coalesce edits for the same buffer when they happen one after the other.
281 let Event::BufferChange {
282 old_snapshot,
283 new_snapshot,
284 timestamp,
285 } = &event;
286
287 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
288 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
289 && old_snapshot.version == last_new_snapshot.version
290 {
291 *last_new_snapshot = new_snapshot.clone();
292 *last_timestamp = *timestamp;
293 return;
294 }
295 }
296
297 if events.len() >= MAX_EVENT_COUNT {
298 // These are halved instead of popping to improve prompt caching.
299 events.drain(..MAX_EVENT_COUNT / 2);
300 }
301
302 events.push_back(event);
303 }
304
    /// Requests an edit prediction for `position` in `buffer`.
    ///
    /// On the foreground, this snapshots the buffer, the project's worktrees,
    /// the syntax-index state handle and the recorded edit history. Context
    /// gathering, request construction and the network round-trip then run on a
    /// background task. The returned task resolves to:
    /// - `Ok(None)` if no excerpt context could be gathered around the cursor, or
    ///   if the returned edits could not be interpolated onto the buffer's current
    ///   snapshot;
    /// - `Ok(Some(_))` with the edits, the snapshot they apply to, and an edit preview;
    /// - `Err(_)` if the buffer has no file or the request fails. A
    ///   [`ZedUpdateRequiredError`] also sets `update_required` and shows an
    ///   "Update Zed" notification.
    ///
    /// If a debug channel was opened with [`Zeta::debug_info`], the request's
    /// debug info (or its error message) is also sent there.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // `None` if this project was never registered. The request then goes out
        // without syntax-index context and without edit history.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert the recorded history into wire events. `old_path` is only sent
        // when the file's path differs between the old and new snapshots.
        // NOTE(review): this computes a full-text unified diff per event on the
        // foreground thread; consider moving it into the background task.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .map(|event| match event {
                        Event::BufferChange {
                            old_snapshot,
                            new_snapshot,
                            ..
                        } => {
                            let path = new_snapshot.file().map(|f| f.path().to_path_buf());

                            let old_path = old_snapshot.file().and_then(|f| {
                                let old_path = f.path().as_ref();
                                if Some(old_path) != path.as_deref() {
                                    Some(old_path.to_path_buf())
                                } else {
                                    None
                                }
                            });

                            predict_edits_v3::Event::BufferChange {
                                old_path,
                                path,
                                diff: language::unified_diff(
                                    &old_snapshot.text(),
                                    &new_snapshot.text(),
                                ),
                                //todo: Actually detect if this edit was predicted or not
                                predicted: false,
                            }
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        // Cloned so they can be moved into the background task.
        let diagnostics = snapshot.diagnostic_sets().clone();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                // Acquire the syntax index lock; the guard is held until this task ends.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                // No usable excerpt around the cursor: no prediction.
                let Some(context) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &options.excerpt,
                    index_state.as_deref(),
                ) else {
                    return Ok(None);
                };

                // Only clone the context when a debug listener is installed.
                let debug_context = if let Some(debug_tx) = debug_tx {
                    Some((debug_tx, context.clone()))
                } else {
                    None
                };

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = make_cloud_request(
                    excerpt_path.clone(),
                    context,
                    events,
                    // TODO data collection
                    false,
                    diagnostic_groups,
                    diagnostic_groups_truncated,
                    None,
                    debug_context.is_some(),
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                );

                // NOTE(review): despite its name, this also covers diagnostic
                // gathering and request construction, not only context retrieval.
                let retrieval_time = chrono::Utc::now() - before_retrieval;
                let response = Self::perform_request(client, llm_token, app_version, request).await;

                // Report either the server's debug info or the error message.
                // Send failures (e.g. the receiver was dropped) are ignored.
                if let Some((debug_tx, context)) = debug_context {
                    debug_tx
                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
                            |response| {
                                let Some(request) =
                                    some_or_debug_panic(response.0.debug_info.clone())
                                else {
                                    return Err("Missing debug info".to_string());
                                };
                                Ok(PredictionDebugInfo {
                                    context,
                                    request,
                                    retrieval_time,
                                    buffer: buffer.downgrade(),
                                    position,
                                })
                            },
                        ))
                        .ok();
                }

                let (response, usage) = response?;
                // Edits are resolved against the snapshot the request was built from.
                let edits = edits_from_response(&response.edits, &snapshot);

                anyhow::Ok(Some((response.request_id, edits, usage)))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((id, edits, usage))) => {
                    // Record the usage reported in the response headers, if any.
                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc
                    // The buffer may have changed while the request was in flight.
                    // Interpolate the edits onto the current snapshot, and drop the
                    // prediction if `interpolate_edits` returns `None`.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: id.into(),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    // The server requires a newer Zed: remember this and prompt the
                    // user to update. The error is still returned to the caller.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }
523
524 async fn perform_request(
525 client: Arc<Client>,
526 llm_token: LlmApiToken,
527 app_version: SemanticVersion,
528 request: predict_edits_v3::PredictEditsRequest,
529 ) -> Result<(
530 predict_edits_v3::PredictEditsResponse,
531 Option<EditPredictionUsage>,
532 )> {
533 let http_client = client.http_client();
534 let mut token = llm_token.acquire(&client).await?;
535 let mut did_retry = false;
536
537 loop {
538 let request_builder = http_client::Request::builder().method(Method::POST);
539 let request_builder =
540 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
541 request_builder.uri(predict_edits_url)
542 } else {
543 request_builder.uri(
544 http_client
545 .build_zed_llm_url("/predict_edits/v3", &[])?
546 .as_ref(),
547 )
548 };
549 let request = request_builder
550 .header("Content-Type", "application/json")
551 .header("Authorization", format!("Bearer {}", token))
552 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
553 .body(serde_json::to_string(&request)?.into())?;
554
555 let mut response = http_client.send(request).await?;
556
557 if let Some(minimum_required_version) = response
558 .headers()
559 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
560 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
561 {
562 anyhow::ensure!(
563 app_version >= minimum_required_version,
564 ZedUpdateRequiredError {
565 minimum_version: minimum_required_version
566 }
567 );
568 }
569
570 if response.status().is_success() {
571 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
572
573 let mut body = Vec::new();
574 response.body_mut().read_to_end(&mut body).await?;
575 return Ok((serde_json::from_slice(&body)?, usage));
576 } else if !did_retry
577 && response
578 .headers()
579 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
580 .is_some()
581 {
582 did_retry = true;
583 token = llm_token.refresh(&client).await?;
584 } else {
585 let mut body = String::new();
586 response.body_mut().read_to_string(&mut body).await?;
587 anyhow::bail!(
588 "error predicting edits.\nStatus: {:?}\nBody: {}",
589 response.status(),
590 body
591 );
592 }
593 }
594 }
595
596 fn gather_nearby_diagnostics(
597 cursor_offset: usize,
598 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
599 snapshot: &BufferSnapshot,
600 max_diagnostics_bytes: usize,
601 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
602 // TODO: Could make this more efficient
603 let mut diagnostic_groups = Vec::new();
604 for (language_server_id, diagnostics) in diagnostic_sets {
605 let mut groups = Vec::new();
606 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
607 diagnostic_groups.extend(
608 groups
609 .into_iter()
610 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
611 );
612 }
613
614 // sort by proximity to cursor
615 diagnostic_groups.sort_by_key(|group| {
616 let range = &group.entries[group.primary_ix].range;
617 if range.start >= cursor_offset {
618 range.start - cursor_offset
619 } else if cursor_offset >= range.end {
620 cursor_offset - range.end
621 } else {
622 (cursor_offset - range.start).min(range.end - cursor_offset)
623 }
624 });
625
626 let mut results = Vec::new();
627 let mut diagnostic_groups_truncated = false;
628 let mut diagnostics_byte_count = 0;
629 for group in diagnostic_groups {
630 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
631 diagnostics_byte_count += raw_value.get().len();
632 if diagnostics_byte_count > max_diagnostics_bytes {
633 diagnostic_groups_truncated = true;
634 break;
635 }
636 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
637 }
638
639 (results, diagnostic_groups_truncated)
640 }
641
642 // TODO: Dedupe with similar code in request_prediction?
643 pub fn cloud_request_for_zeta_cli(
644 &mut self,
645 project: &Entity<Project>,
646 buffer: &Entity<Buffer>,
647 position: language::Anchor,
648 cx: &mut Context<Self>,
649 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
650 let project_state = self.projects.get(&project.entity_id());
651
652 let index_state = project_state.map(|state| {
653 state
654 .syntax_index
655 .read_with(cx, |index, _cx| index.state().clone())
656 });
657 let options = self.options.clone();
658 let snapshot = buffer.read(cx).snapshot();
659 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
660 return Task::ready(Err(anyhow!("No file path for excerpt")));
661 };
662 let worktree_snapshots = project
663 .read(cx)
664 .worktrees(cx)
665 .map(|worktree| worktree.read(cx).snapshot())
666 .collect::<Vec<_>>();
667
668 cx.background_spawn(async move {
669 let index_state = if let Some(index_state) = index_state {
670 Some(index_state.lock_owned().await)
671 } else {
672 None
673 };
674
675 let cursor_point = position.to_point(&snapshot);
676
677 let debug_info = true;
678 EditPredictionContext::gather_context(
679 cursor_point,
680 &snapshot,
681 &options.excerpt,
682 index_state.as_deref(),
683 )
684 .context("Failed to select excerpt")
685 .map(|context| {
686 make_cloud_request(
687 excerpt_path.clone(),
688 context,
689 // TODO pass everything
690 Vec::new(),
691 false,
692 Vec::new(),
693 false,
694 None,
695 debug_info,
696 &worktree_snapshots,
697 index_state.as_deref(),
698 Some(options.max_prompt_bytes),
699 )
700 })
701 })
702 }
703}
704
/// Returned by `Zeta::perform_request` when the server's minimum-required-version
/// header names a version newer than the running Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
712
713fn make_cloud_request(
714 excerpt_path: PathBuf,
715 context: EditPredictionContext,
716 events: Vec<predict_edits_v3::Event>,
717 can_collect_data: bool,
718 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
719 diagnostic_groups_truncated: bool,
720 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
721 debug_info: bool,
722 worktrees: &Vec<worktree::Snapshot>,
723 index_state: Option<&SyntaxIndexState>,
724 prompt_max_bytes: Option<usize>,
725) -> predict_edits_v3::PredictEditsRequest {
726 let mut signatures = Vec::new();
727 let mut declaration_to_signature_index = HashMap::default();
728 let mut referenced_declarations = Vec::new();
729
730 for snippet in context.snippets {
731 let project_entry_id = snippet.declaration.project_entry_id();
732 let Some(path) = worktrees.iter().find_map(|worktree| {
733 worktree.entry_for_id(project_entry_id).map(|entry| {
734 let mut full_path = PathBuf::new();
735 full_path.push(worktree.root_name());
736 full_path.push(&entry.path);
737 full_path
738 })
739 }) else {
740 continue;
741 };
742
743 let parent_index = index_state.and_then(|index_state| {
744 snippet.declaration.parent().and_then(|parent| {
745 add_signature(
746 parent,
747 &mut declaration_to_signature_index,
748 &mut signatures,
749 index_state,
750 )
751 })
752 });
753
754 let (text, text_is_truncated) = snippet.declaration.item_text();
755 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
756 path,
757 text: text.into(),
758 range: snippet.declaration.item_range(),
759 text_is_truncated,
760 signature_range: snippet.declaration.signature_range_in_item_text(),
761 parent_index,
762 score_components: snippet.score_components,
763 signature_score: snippet.scores.signature,
764 declaration_score: snippet.scores.declaration,
765 });
766 }
767
768 let excerpt_parent = index_state.and_then(|index_state| {
769 context
770 .excerpt
771 .parent_declarations
772 .last()
773 .and_then(|(parent, _)| {
774 add_signature(
775 *parent,
776 &mut declaration_to_signature_index,
777 &mut signatures,
778 index_state,
779 )
780 })
781 });
782
783 predict_edits_v3::PredictEditsRequest {
784 excerpt_path,
785 excerpt: context.excerpt_text.body,
786 excerpt_range: context.excerpt.range,
787 cursor_offset: context.cursor_offset_in_excerpt,
788 referenced_declarations,
789 signatures,
790 excerpt_parent,
791 events,
792 can_collect_data,
793 diagnostic_groups,
794 diagnostic_groups_truncated,
795 git_info,
796 debug_info,
797 prompt_max_bytes,
798 }
799}
800
801fn add_signature(
802 declaration_id: DeclarationId,
803 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
804 signatures: &mut Vec<Signature>,
805 index: &SyntaxIndexState,
806) -> Option<usize> {
807 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
808 return Some(*signature_index);
809 }
810 let Some(parent_declaration) = index.declaration(declaration_id) else {
811 log::error!("bug: missing parent declaration");
812 return None;
813 };
814 let parent_index = parent_declaration.parent().and_then(|parent| {
815 add_signature(parent, declaration_to_signature_index, signatures, index)
816 });
817 let (text, text_is_truncated) = parent_declaration.signature_text();
818 let signature_index = signatures.len();
819 signatures.push(Signature {
820 text: text.into(),
821 text_is_truncated,
822 parent_index,
823 range: parent_declaration.signature_range(),
824 });
825 declaration_to_signature_index.insert(declaration_id, signature_index);
826 Some(signature_index)
827}