1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::PathBuf;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::some_or_debug_panic;
32use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
33
34mod prediction;
35mod provider;
36
37use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
38pub use provider::ZetaEditPredictionProvider;
39
/// Consecutive edits to the same buffer that arrive within this window of each other are
/// coalesced into a single [`Event::BufferChange`].
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);
41
/// Upper bound on the number of buffer-change events kept per project.
///
/// When the history reaches this size, the oldest half is discarded in one go.
const MAX_EVENT_COUNT: usize = 16;
44
45pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
46 max_bytes: 512,
47 min_bytes: 128,
48 target_before_cursor_over_total_bytes: 0.5,
49};
50
51pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
52 excerpt: DEFAULT_EXCERPT_OPTIONS,
53 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
54 max_diagnostic_bytes: 2048,
55};
56
57#[derive(Clone)]
58struct ZetaGlobal(Entity<Zeta>);
59
60impl Global for ZetaGlobal {}
61
62pub struct Zeta {
63 client: Arc<Client>,
64 user_store: Entity<UserStore>,
65 llm_token: LlmApiToken,
66 _llm_token_subscription: Subscription,
67 projects: HashMap<EntityId, ZetaProject>,
68 options: ZetaOptions,
69 update_required: bool,
70 debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
71}
72
73#[derive(Debug, Clone, PartialEq)]
74pub struct ZetaOptions {
75 pub excerpt: EditPredictionExcerptOptions,
76 pub max_prompt_bytes: usize,
77 pub max_diagnostic_bytes: usize,
78}
79
80pub struct PredictionDebugInfo {
81 pub context: EditPredictionContext,
82 pub retrieval_time: TimeDelta,
83 pub request: RequestDebugInfo,
84 pub buffer: WeakEntity<Buffer>,
85 pub position: language::Anchor,
86}
87
88pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
89
90struct ZetaProject {
91 syntax_index: Entity<SyntaxIndex>,
92 events: VecDeque<Event>,
93 registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
94}
95
96struct RegisteredBuffer {
97 snapshot: BufferSnapshot,
98 _subscriptions: [gpui::Subscription; 2],
99}
100
101#[derive(Clone)]
102pub enum Event {
103 BufferChange {
104 old_snapshot: BufferSnapshot,
105 new_snapshot: BufferSnapshot,
106 timestamp: Instant,
107 },
108}
109
110impl Zeta {
111 pub fn global(
112 client: &Arc<Client>,
113 user_store: &Entity<UserStore>,
114 cx: &mut App,
115 ) -> Entity<Self> {
116 cx.try_global::<ZetaGlobal>()
117 .map(|global| global.0.clone())
118 .unwrap_or_else(|| {
119 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
120 cx.set_global(ZetaGlobal(zeta.clone()));
121 zeta
122 })
123 }
124
125 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
126 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
127
128 Self {
129 projects: HashMap::new(),
130 client,
131 user_store,
132 options: DEFAULT_OPTIONS,
133 llm_token: LlmApiToken::default(),
134 _llm_token_subscription: cx.subscribe(
135 &refresh_llm_token_listener,
136 |this, _listener, _event, cx| {
137 let client = this.client.clone();
138 let llm_token = this.llm_token.clone();
139 cx.spawn(async move |_this, _cx| {
140 llm_token.refresh(&client).await?;
141 anyhow::Ok(())
142 })
143 .detach_and_log_err(cx);
144 },
145 ),
146 update_required: false,
147 debug_tx: None,
148 }
149 }
150
151 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
152 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
153 self.debug_tx = Some(debug_watch_tx);
154 debug_watch_rx
155 }
156
157 pub fn options(&self) -> &ZetaOptions {
158 &self.options
159 }
160
161 pub fn set_options(&mut self, options: ZetaOptions) {
162 self.options = options;
163 }
164
165 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
166 self.user_store.read(cx).edit_prediction_usage()
167 }
168
169 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
170 self.get_or_init_zeta_project(project, cx);
171 }
172
173 pub fn register_buffer(
174 &mut self,
175 buffer: &Entity<Buffer>,
176 project: &Entity<Project>,
177 cx: &mut Context<Self>,
178 ) {
179 let zeta_project = self.get_or_init_zeta_project(project, cx);
180 Self::register_buffer_impl(zeta_project, buffer, project, cx);
181 }
182
183 fn get_or_init_zeta_project(
184 &mut self,
185 project: &Entity<Project>,
186 cx: &mut App,
187 ) -> &mut ZetaProject {
188 self.projects
189 .entry(project.entity_id())
190 .or_insert_with(|| ZetaProject {
191 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
192 events: VecDeque::new(),
193 registered_buffers: HashMap::new(),
194 })
195 }
196
197 fn register_buffer_impl<'a>(
198 zeta_project: &'a mut ZetaProject,
199 buffer: &Entity<Buffer>,
200 project: &Entity<Project>,
201 cx: &mut Context<Self>,
202 ) -> &'a mut RegisteredBuffer {
203 let buffer_id = buffer.entity_id();
204 match zeta_project.registered_buffers.entry(buffer_id) {
205 hash_map::Entry::Occupied(entry) => entry.into_mut(),
206 hash_map::Entry::Vacant(entry) => {
207 let snapshot = buffer.read(cx).snapshot();
208 let project_entity_id = project.entity_id();
209 entry.insert(RegisteredBuffer {
210 snapshot,
211 _subscriptions: [
212 cx.subscribe(buffer, {
213 let project = project.downgrade();
214 move |this, buffer, event, cx| {
215 if let language::BufferEvent::Edited = event
216 && let Some(project) = project.upgrade()
217 {
218 this.report_changes_for_buffer(&buffer, &project, cx);
219 }
220 }
221 }),
222 cx.observe_release(buffer, move |this, _buffer, _cx| {
223 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
224 else {
225 return;
226 };
227 zeta_project.registered_buffers.remove(&buffer_id);
228 }),
229 ],
230 })
231 }
232 }
233 }
234
235 fn report_changes_for_buffer(
236 &mut self,
237 buffer: &Entity<Buffer>,
238 project: &Entity<Project>,
239 cx: &mut Context<Self>,
240 ) -> BufferSnapshot {
241 let zeta_project = self.get_or_init_zeta_project(project, cx);
242 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
243
244 let new_snapshot = buffer.read(cx).snapshot();
245 if new_snapshot.version != registered_buffer.snapshot.version {
246 let old_snapshot =
247 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
248 Self::push_event(
249 zeta_project,
250 Event::BufferChange {
251 old_snapshot,
252 new_snapshot: new_snapshot.clone(),
253 timestamp: Instant::now(),
254 },
255 );
256 }
257
258 new_snapshot
259 }
260
261 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
262 let events = &mut zeta_project.events;
263
264 if let Some(Event::BufferChange {
265 new_snapshot: last_new_snapshot,
266 timestamp: last_timestamp,
267 ..
268 }) = events.back_mut()
269 {
270 // Coalesce edits for the same buffer when they happen one after the other.
271 let Event::BufferChange {
272 old_snapshot,
273 new_snapshot,
274 timestamp,
275 } = &event;
276
277 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
278 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
279 && old_snapshot.version == last_new_snapshot.version
280 {
281 *last_new_snapshot = new_snapshot.clone();
282 *last_timestamp = *timestamp;
283 return;
284 }
285 }
286
287 if events.len() >= MAX_EVENT_COUNT {
288 // These are halved instead of popping to improve prompt caching.
289 events.drain(..MAX_EVENT_COUNT / 2);
290 }
291
292 events.push_back(event);
293 }
294
295 pub fn request_prediction(
296 &mut self,
297 project: &Entity<Project>,
298 buffer: &Entity<Buffer>,
299 position: language::Anchor,
300 cx: &mut Context<Self>,
301 ) -> Task<Result<Option<EditPrediction>>> {
302 let project_state = self.projects.get(&project.entity_id());
303
304 let index_state = project_state.map(|state| {
305 state
306 .syntax_index
307 .read_with(cx, |index, _cx| index.state().clone())
308 });
309 let options = self.options.clone();
310 let snapshot = buffer.read(cx).snapshot();
311 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
312 return Task::ready(Err(anyhow!("No file path for excerpt")));
313 };
314 let client = self.client.clone();
315 let llm_token = self.llm_token.clone();
316 let app_version = AppVersion::global(cx);
317 let worktree_snapshots = project
318 .read(cx)
319 .worktrees(cx)
320 .map(|worktree| worktree.read(cx).snapshot())
321 .collect::<Vec<_>>();
322 let debug_tx = self.debug_tx.clone();
323
324 let events = project_state
325 .map(|state| {
326 state
327 .events
328 .iter()
329 .map(|event| match event {
330 Event::BufferChange {
331 old_snapshot,
332 new_snapshot,
333 ..
334 } => {
335 let path = new_snapshot.file().map(|f| f.path().to_path_buf());
336
337 let old_path = old_snapshot.file().and_then(|f| {
338 let old_path = f.path().as_ref();
339 if Some(old_path) != path.as_deref() {
340 Some(old_path.to_path_buf())
341 } else {
342 None
343 }
344 });
345
346 predict_edits_v3::Event::BufferChange {
347 old_path,
348 path,
349 diff: language::unified_diff(
350 &old_snapshot.text(),
351 &new_snapshot.text(),
352 ),
353 //todo: Actually detect if this edit was predicted or not
354 predicted: false,
355 }
356 }
357 })
358 .collect::<Vec<_>>()
359 })
360 .unwrap_or_default();
361
362 let diagnostics = snapshot.diagnostic_sets().clone();
363
364 let request_task = cx.background_spawn({
365 let snapshot = snapshot.clone();
366 let buffer = buffer.clone();
367 async move {
368 let index_state = if let Some(index_state) = index_state {
369 Some(index_state.lock_owned().await)
370 } else {
371 None
372 };
373
374 let cursor_offset = position.to_offset(&snapshot);
375 let cursor_point = cursor_offset.to_point(&snapshot);
376
377 let before_retrieval = chrono::Utc::now();
378
379 let Some(context) = EditPredictionContext::gather_context(
380 cursor_point,
381 &snapshot,
382 &options.excerpt,
383 index_state.as_deref(),
384 ) else {
385 return Ok(None);
386 };
387
388 let debug_context = if let Some(debug_tx) = debug_tx {
389 Some((debug_tx, context.clone()))
390 } else {
391 None
392 };
393
394 let (diagnostic_groups, diagnostic_groups_truncated) =
395 Self::gather_nearby_diagnostics(
396 cursor_offset,
397 &diagnostics,
398 &snapshot,
399 options.max_diagnostic_bytes,
400 );
401
402 let request = make_cloud_request(
403 excerpt_path.clone(),
404 context,
405 events,
406 // TODO data collection
407 false,
408 diagnostic_groups,
409 diagnostic_groups_truncated,
410 None,
411 debug_context.is_some(),
412 &worktree_snapshots,
413 index_state.as_deref(),
414 Some(options.max_prompt_bytes),
415 );
416
417 let retrieval_time = chrono::Utc::now() - before_retrieval;
418 let response = Self::perform_request(client, llm_token, app_version, request).await;
419
420 if let Some((debug_tx, context)) = debug_context {
421 debug_tx
422 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
423 |response| {
424 let Some(request) =
425 some_or_debug_panic(response.0.debug_info.clone())
426 else {
427 return Err("Missing debug info".to_string());
428 };
429 Ok(PredictionDebugInfo {
430 context,
431 request,
432 retrieval_time,
433 buffer: buffer.downgrade(),
434 position,
435 })
436 },
437 ))
438 .ok();
439 }
440
441 let (response, usage) = response?;
442 let edits = edits_from_response(&response.edits, &snapshot);
443
444 anyhow::Ok(Some((response.request_id, edits, usage)))
445 }
446 });
447
448 let buffer = buffer.clone();
449
450 cx.spawn(async move |this, cx| {
451 match request_task.await {
452 Ok(Some((id, edits, usage))) => {
453 if let Some(usage) = usage {
454 this.update(cx, |this, cx| {
455 this.user_store.update(cx, |user_store, cx| {
456 user_store.update_edit_prediction_usage(usage, cx);
457 });
458 })
459 .ok();
460 }
461
462 // TODO telemetry: duration, etc
463 let Some((edits, snapshot, edit_preview_task)) =
464 buffer.read_with(cx, |buffer, cx| {
465 let new_snapshot = buffer.snapshot();
466 let edits: Arc<[_]> =
467 interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
468 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
469 })?
470 else {
471 return Ok(None);
472 };
473
474 Ok(Some(EditPrediction {
475 id: id.into(),
476 edits,
477 snapshot,
478 edit_preview: edit_preview_task.await,
479 }))
480 }
481 Ok(None) => Ok(None),
482 Err(err) => {
483 if err.is::<ZedUpdateRequiredError>() {
484 cx.update(|cx| {
485 this.update(cx, |this, _cx| {
486 this.update_required = true;
487 })
488 .ok();
489
490 let error_message: SharedString = err.to_string().into();
491 show_app_notification(
492 NotificationId::unique::<ZedUpdateRequiredError>(),
493 cx,
494 move |cx| {
495 cx.new(|cx| {
496 ErrorMessagePrompt::new(error_message.clone(), cx)
497 .with_link_button(
498 "Update Zed",
499 "https://zed.dev/releases",
500 )
501 })
502 },
503 );
504 })
505 .ok();
506 }
507
508 Err(err)
509 }
510 }
511 })
512 }
513
514 async fn perform_request(
515 client: Arc<Client>,
516 llm_token: LlmApiToken,
517 app_version: SemanticVersion,
518 request: predict_edits_v3::PredictEditsRequest,
519 ) -> Result<(
520 predict_edits_v3::PredictEditsResponse,
521 Option<EditPredictionUsage>,
522 )> {
523 let http_client = client.http_client();
524 let mut token = llm_token.acquire(&client).await?;
525 let mut did_retry = false;
526
527 loop {
528 let request_builder = http_client::Request::builder().method(Method::POST);
529 let request_builder =
530 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
531 request_builder.uri(predict_edits_url)
532 } else {
533 request_builder.uri(
534 http_client
535 .build_zed_llm_url("/predict_edits/v3", &[])?
536 .as_ref(),
537 )
538 };
539 let request = request_builder
540 .header("Content-Type", "application/json")
541 .header("Authorization", format!("Bearer {}", token))
542 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
543 .body(serde_json::to_string(&request)?.into())?;
544
545 let mut response = http_client.send(request).await?;
546
547 if let Some(minimum_required_version) = response
548 .headers()
549 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
550 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
551 {
552 anyhow::ensure!(
553 app_version >= minimum_required_version,
554 ZedUpdateRequiredError {
555 minimum_version: minimum_required_version
556 }
557 );
558 }
559
560 if response.status().is_success() {
561 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
562
563 let mut body = Vec::new();
564 response.body_mut().read_to_end(&mut body).await?;
565 return Ok((serde_json::from_slice(&body)?, usage));
566 } else if !did_retry
567 && response
568 .headers()
569 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
570 .is_some()
571 {
572 did_retry = true;
573 token = llm_token.refresh(&client).await?;
574 } else {
575 let mut body = String::new();
576 response.body_mut().read_to_string(&mut body).await?;
577 anyhow::bail!(
578 "error predicting edits.\nStatus: {:?}\nBody: {}",
579 response.status(),
580 body
581 );
582 }
583 }
584 }
585
586 fn gather_nearby_diagnostics(
587 cursor_offset: usize,
588 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
589 snapshot: &BufferSnapshot,
590 max_diagnostics_bytes: usize,
591 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
592 // TODO: Could make this more efficient
593 let mut diagnostic_groups = Vec::new();
594 for (language_server_id, diagnostics) in diagnostic_sets {
595 let mut groups = Vec::new();
596 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
597 diagnostic_groups.extend(
598 groups
599 .into_iter()
600 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
601 );
602 }
603
604 // sort by proximity to cursor
605 diagnostic_groups.sort_by_key(|group| {
606 let range = &group.entries[group.primary_ix].range;
607 if range.start >= cursor_offset {
608 range.start - cursor_offset
609 } else if cursor_offset >= range.end {
610 cursor_offset - range.end
611 } else {
612 (cursor_offset - range.start).min(range.end - cursor_offset)
613 }
614 });
615
616 let mut results = Vec::new();
617 let mut diagnostic_groups_truncated = false;
618 let mut diagnostics_byte_count = 0;
619 for group in diagnostic_groups {
620 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
621 diagnostics_byte_count += raw_value.get().len();
622 if diagnostics_byte_count > max_diagnostics_bytes {
623 diagnostic_groups_truncated = true;
624 break;
625 }
626 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
627 }
628
629 (results, diagnostic_groups_truncated)
630 }
631
632 // TODO: Dedupe with similar code in request_prediction?
633 pub fn cloud_request_for_zeta_cli(
634 &mut self,
635 project: &Entity<Project>,
636 buffer: &Entity<Buffer>,
637 position: language::Anchor,
638 cx: &mut Context<Self>,
639 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
640 let project_state = self.projects.get(&project.entity_id());
641
642 let index_state = project_state.map(|state| {
643 state
644 .syntax_index
645 .read_with(cx, |index, _cx| index.state().clone())
646 });
647 let options = self.options.clone();
648 let snapshot = buffer.read(cx).snapshot();
649 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
650 return Task::ready(Err(anyhow!("No file path for excerpt")));
651 };
652 let worktree_snapshots = project
653 .read(cx)
654 .worktrees(cx)
655 .map(|worktree| worktree.read(cx).snapshot())
656 .collect::<Vec<_>>();
657
658 cx.background_spawn(async move {
659 let index_state = if let Some(index_state) = index_state {
660 Some(index_state.lock_owned().await)
661 } else {
662 None
663 };
664
665 let cursor_point = position.to_point(&snapshot);
666
667 let debug_info = true;
668 EditPredictionContext::gather_context(
669 cursor_point,
670 &snapshot,
671 &options.excerpt,
672 index_state.as_deref(),
673 )
674 .context("Failed to select excerpt")
675 .map(|context| {
676 make_cloud_request(
677 excerpt_path.clone(),
678 context,
679 // TODO pass everything
680 Vec::new(),
681 false,
682 Vec::new(),
683 false,
684 None,
685 debug_info,
686 &worktree_snapshots,
687 index_state.as_deref(),
688 Some(options.max_prompt_bytes),
689 )
690 })
691 })
692 }
693}
694
695#[derive(Error, Debug)]
696#[error(
697 "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
698)]
699pub struct ZedUpdateRequiredError {
700 minimum_version: SemanticVersion,
701}
702
703fn make_cloud_request(
704 excerpt_path: PathBuf,
705 context: EditPredictionContext,
706 events: Vec<predict_edits_v3::Event>,
707 can_collect_data: bool,
708 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
709 diagnostic_groups_truncated: bool,
710 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
711 debug_info: bool,
712 worktrees: &Vec<worktree::Snapshot>,
713 index_state: Option<&SyntaxIndexState>,
714 prompt_max_bytes: Option<usize>,
715) -> predict_edits_v3::PredictEditsRequest {
716 let mut signatures = Vec::new();
717 let mut declaration_to_signature_index = HashMap::default();
718 let mut referenced_declarations = Vec::new();
719
720 for snippet in context.snippets {
721 let project_entry_id = snippet.declaration.project_entry_id();
722 let Some(path) = worktrees.iter().find_map(|worktree| {
723 worktree.entry_for_id(project_entry_id).map(|entry| {
724 let mut full_path = PathBuf::new();
725 full_path.push(worktree.root_name());
726 full_path.push(&entry.path);
727 full_path
728 })
729 }) else {
730 continue;
731 };
732
733 let parent_index = index_state.and_then(|index_state| {
734 snippet.declaration.parent().and_then(|parent| {
735 add_signature(
736 parent,
737 &mut declaration_to_signature_index,
738 &mut signatures,
739 index_state,
740 )
741 })
742 });
743
744 let (text, text_is_truncated) = snippet.declaration.item_text();
745 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
746 path,
747 text: text.into(),
748 range: snippet.declaration.item_range(),
749 text_is_truncated,
750 signature_range: snippet.declaration.signature_range_in_item_text(),
751 parent_index,
752 score_components: snippet.score_components,
753 signature_score: snippet.scores.signature,
754 declaration_score: snippet.scores.declaration,
755 });
756 }
757
758 let excerpt_parent = index_state.and_then(|index_state| {
759 context
760 .excerpt
761 .parent_declarations
762 .last()
763 .and_then(|(parent, _)| {
764 add_signature(
765 *parent,
766 &mut declaration_to_signature_index,
767 &mut signatures,
768 index_state,
769 )
770 })
771 });
772
773 predict_edits_v3::PredictEditsRequest {
774 excerpt_path,
775 excerpt: context.excerpt_text.body,
776 excerpt_range: context.excerpt.range,
777 cursor_offset: context.cursor_offset_in_excerpt,
778 referenced_declarations,
779 signatures,
780 excerpt_parent,
781 events,
782 can_collect_data,
783 diagnostic_groups,
784 diagnostic_groups_truncated,
785 git_info,
786 debug_info,
787 prompt_max_bytes,
788 }
789}
790
791fn add_signature(
792 declaration_id: DeclarationId,
793 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
794 signatures: &mut Vec<Signature>,
795 index: &SyntaxIndexState,
796) -> Option<usize> {
797 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
798 return Some(*signature_index);
799 }
800 let Some(parent_declaration) = index.declaration(declaration_id) else {
801 log::error!("bug: missing parent declaration");
802 return None;
803 };
804 let parent_index = parent_declaration.parent().and_then(|parent| {
805 add_signature(parent, declaration_to_signature_index, signatures, index)
806 });
807 let (text, text_is_truncated) = parent_declaration.signature_text();
808 let signature_index = signatures.len();
809 signatures.push(Signature {
810 text: text.into(),
811 text_is_truncated,
812 parent_index,
813 range: parent_declaration.signature_range(),
814 });
815 declaration_to_signature_index.insert(declaration_id, signature_index);
816 Some(signature_index)
817}