1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::PathBuf;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
39pub use provider::ZetaEditPredictionProvider;
40
/// Window used to merge edits: a buffer change recorded no later than this after the previous
/// event, for the same buffer and continuing from its version, is folded into that event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);

/// Upper bound on the number of edit events retained per project; once reached, the oldest
/// half of the history is discarded.
const MAX_EVENT_COUNT: usize = 16;
45
46pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
47 max_bytes: 512,
48 min_bytes: 128,
49 target_before_cursor_over_total_bytes: 0.5,
50};
51
52pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
53 excerpt: DEFAULT_EXCERPT_OPTIONS,
54 max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
55 max_diagnostic_bytes: 2048,
56};
57
/// Newtype that stores the app-wide [`Zeta`] entity as a gpui global; written by
/// [`Zeta::global`] and read by [`Zeta::try_global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
62
/// Client-side state for requesting edit predictions from the cloud prediction endpoint.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate prediction requests; a refresh is triggered whenever the
    /// global `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the `RefreshLlmTokenListener` subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project prediction state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When set, every prediction request sends its debug info (or an error string) here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
73
/// Tunable limits used when building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Controls how the excerpt around the cursor is selected.
    pub excerpt: EditPredictionExcerptOptions,
    /// Forwarded to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups attached to a request.
    pub max_diagnostic_bytes: usize,
}
80
/// Debug information sent for each prediction request while a channel obtained from
/// `Zeta::debug_info` is installed.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time from just before context gathering until the cloud request was built (before it
    /// was sent).
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in its response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}
88
/// Debug payload the server returns in a prediction response's `debug_info` field.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
90
/// Prediction state tracked for a single project.
struct ZetaProject {
    /// Syntax index consulted when gathering context for a request.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent edit events, oldest first; never holds more than `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers whose edits are being recorded, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
96
/// Tracking state for a buffer whose edits are recorded as events.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Keeps the buffer's edit subscription and release observer alive.
    _subscriptions: [gpui::Subscription; 2],
}
101
/// An entry in a project's edit history; converted to `predict_edits_v3::Event`s when a
/// prediction is requested.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes may be
    /// coalesced, in which case `new_snapshot` and `timestamp` reflect the latest one.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
110
111impl Zeta {
112 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
113 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
114 }
115
116 pub fn global(
117 client: &Arc<Client>,
118 user_store: &Entity<UserStore>,
119 cx: &mut App,
120 ) -> Entity<Self> {
121 cx.try_global::<ZetaGlobal>()
122 .map(|global| global.0.clone())
123 .unwrap_or_else(|| {
124 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
125 cx.set_global(ZetaGlobal(zeta.clone()));
126 zeta
127 })
128 }
129
130 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
131 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
132
133 Self {
134 projects: HashMap::new(),
135 client,
136 user_store,
137 options: DEFAULT_OPTIONS,
138 llm_token: LlmApiToken::default(),
139 _llm_token_subscription: cx.subscribe(
140 &refresh_llm_token_listener,
141 |this, _listener, _event, cx| {
142 let client = this.client.clone();
143 let llm_token = this.llm_token.clone();
144 cx.spawn(async move |_this, _cx| {
145 llm_token.refresh(&client).await?;
146 anyhow::Ok(())
147 })
148 .detach_and_log_err(cx);
149 },
150 ),
151 update_required: false,
152 debug_tx: None,
153 }
154 }
155
156 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
157 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
158 self.debug_tx = Some(debug_watch_tx);
159 debug_watch_rx
160 }
161
162 pub fn options(&self) -> &ZetaOptions {
163 &self.options
164 }
165
166 pub fn set_options(&mut self, options: ZetaOptions) {
167 self.options = options;
168 }
169
170 pub fn clear_history(&mut self) {
171 for zeta_project in self.projects.values_mut() {
172 zeta_project.events.clear();
173 }
174 }
175
176 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
177 self.user_store.read(cx).edit_prediction_usage()
178 }
179
180 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
181 self.get_or_init_zeta_project(project, cx);
182 }
183
184 pub fn register_buffer(
185 &mut self,
186 buffer: &Entity<Buffer>,
187 project: &Entity<Project>,
188 cx: &mut Context<Self>,
189 ) {
190 let zeta_project = self.get_or_init_zeta_project(project, cx);
191 Self::register_buffer_impl(zeta_project, buffer, project, cx);
192 }
193
194 fn get_or_init_zeta_project(
195 &mut self,
196 project: &Entity<Project>,
197 cx: &mut App,
198 ) -> &mut ZetaProject {
199 self.projects
200 .entry(project.entity_id())
201 .or_insert_with(|| ZetaProject {
202 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
203 events: VecDeque::new(),
204 registered_buffers: HashMap::new(),
205 })
206 }
207
208 fn register_buffer_impl<'a>(
209 zeta_project: &'a mut ZetaProject,
210 buffer: &Entity<Buffer>,
211 project: &Entity<Project>,
212 cx: &mut Context<Self>,
213 ) -> &'a mut RegisteredBuffer {
214 let buffer_id = buffer.entity_id();
215 match zeta_project.registered_buffers.entry(buffer_id) {
216 hash_map::Entry::Occupied(entry) => entry.into_mut(),
217 hash_map::Entry::Vacant(entry) => {
218 let snapshot = buffer.read(cx).snapshot();
219 let project_entity_id = project.entity_id();
220 entry.insert(RegisteredBuffer {
221 snapshot,
222 _subscriptions: [
223 cx.subscribe(buffer, {
224 let project = project.downgrade();
225 move |this, buffer, event, cx| {
226 if let language::BufferEvent::Edited = event
227 && let Some(project) = project.upgrade()
228 {
229 this.report_changes_for_buffer(&buffer, &project, cx);
230 }
231 }
232 }),
233 cx.observe_release(buffer, move |this, _buffer, _cx| {
234 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
235 else {
236 return;
237 };
238 zeta_project.registered_buffers.remove(&buffer_id);
239 }),
240 ],
241 })
242 }
243 }
244 }
245
246 fn report_changes_for_buffer(
247 &mut self,
248 buffer: &Entity<Buffer>,
249 project: &Entity<Project>,
250 cx: &mut Context<Self>,
251 ) -> BufferSnapshot {
252 let zeta_project = self.get_or_init_zeta_project(project, cx);
253 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
254
255 let new_snapshot = buffer.read(cx).snapshot();
256 if new_snapshot.version != registered_buffer.snapshot.version {
257 let old_snapshot =
258 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
259 Self::push_event(
260 zeta_project,
261 Event::BufferChange {
262 old_snapshot,
263 new_snapshot: new_snapshot.clone(),
264 timestamp: Instant::now(),
265 },
266 );
267 }
268
269 new_snapshot
270 }
271
272 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
273 let events = &mut zeta_project.events;
274
275 if let Some(Event::BufferChange {
276 new_snapshot: last_new_snapshot,
277 timestamp: last_timestamp,
278 ..
279 }) = events.back_mut()
280 {
281 // Coalesce edits for the same buffer when they happen one after the other.
282 let Event::BufferChange {
283 old_snapshot,
284 new_snapshot,
285 timestamp,
286 } = &event;
287
288 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
289 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
290 && old_snapshot.version == last_new_snapshot.version
291 {
292 *last_new_snapshot = new_snapshot.clone();
293 *last_timestamp = *timestamp;
294 return;
295 }
296 }
297
298 if events.len() >= MAX_EVENT_COUNT {
299 // These are halved instead of popping to improve prompt caching.
300 events.drain(..MAX_EVENT_COUNT / 2);
301 }
302
303 events.push_back(event);
304 }
305
306 pub fn request_prediction(
307 &mut self,
308 project: &Entity<Project>,
309 buffer: &Entity<Buffer>,
310 position: language::Anchor,
311 cx: &mut Context<Self>,
312 ) -> Task<Result<Option<EditPrediction>>> {
313 let project_state = self.projects.get(&project.entity_id());
314
315 let index_state = project_state.map(|state| {
316 state
317 .syntax_index
318 .read_with(cx, |index, _cx| index.state().clone())
319 });
320 let options = self.options.clone();
321 let snapshot = buffer.read(cx).snapshot();
322 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
323 return Task::ready(Err(anyhow!("No file path for excerpt")));
324 };
325 let client = self.client.clone();
326 let llm_token = self.llm_token.clone();
327 let app_version = AppVersion::global(cx);
328 let worktree_snapshots = project
329 .read(cx)
330 .worktrees(cx)
331 .map(|worktree| worktree.read(cx).snapshot())
332 .collect::<Vec<_>>();
333 let debug_tx = self.debug_tx.clone();
334
335 let events = project_state
336 .map(|state| {
337 state
338 .events
339 .iter()
340 .map(|event| match event {
341 Event::BufferChange {
342 old_snapshot,
343 new_snapshot,
344 ..
345 } => {
346 let path = new_snapshot.file().map(|f| f.path().clone());
347
348 let old_path = old_snapshot.file().and_then(|f| {
349 let old_path = f.path();
350 if Some(old_path) != path.as_ref() {
351 Some(old_path.clone())
352 } else {
353 None
354 }
355 });
356
357 predict_edits_v3::Event::BufferChange {
358 old_path: old_path
359 .map(|old_path| old_path.as_std_path().to_path_buf()),
360 path: path.map(|path| path.as_std_path().to_path_buf()),
361 diff: language::unified_diff(
362 &old_snapshot.text(),
363 &new_snapshot.text(),
364 ),
365 //todo: Actually detect if this edit was predicted or not
366 predicted: false,
367 }
368 }
369 })
370 .collect::<Vec<_>>()
371 })
372 .unwrap_or_default();
373
374 let diagnostics = snapshot.diagnostic_sets().clone();
375
376 let request_task = cx.background_spawn({
377 let snapshot = snapshot.clone();
378 let buffer = buffer.clone();
379 async move {
380 let index_state = if let Some(index_state) = index_state {
381 Some(index_state.lock_owned().await)
382 } else {
383 None
384 };
385
386 let cursor_offset = position.to_offset(&snapshot);
387 let cursor_point = cursor_offset.to_point(&snapshot);
388
389 let before_retrieval = chrono::Utc::now();
390
391 let Some(context) = EditPredictionContext::gather_context(
392 cursor_point,
393 &snapshot,
394 &options.excerpt,
395 index_state.as_deref(),
396 ) else {
397 return Ok(None);
398 };
399
400 let debug_context = if let Some(debug_tx) = debug_tx {
401 Some((debug_tx, context.clone()))
402 } else {
403 None
404 };
405
406 let (diagnostic_groups, diagnostic_groups_truncated) =
407 Self::gather_nearby_diagnostics(
408 cursor_offset,
409 &diagnostics,
410 &snapshot,
411 options.max_diagnostic_bytes,
412 );
413
414 let request = make_cloud_request(
415 excerpt_path.clone(),
416 context,
417 events,
418 // TODO data collection
419 false,
420 diagnostic_groups,
421 diagnostic_groups_truncated,
422 None,
423 debug_context.is_some(),
424 &worktree_snapshots,
425 index_state.as_deref(),
426 Some(options.max_prompt_bytes),
427 );
428
429 let retrieval_time = chrono::Utc::now() - before_retrieval;
430 let response = Self::perform_request(client, llm_token, app_version, request).await;
431
432 if let Some((debug_tx, context)) = debug_context {
433 debug_tx
434 .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
435 |response| {
436 let Some(request) =
437 some_or_debug_panic(response.0.debug_info.clone())
438 else {
439 return Err("Missing debug info".to_string());
440 };
441 Ok(PredictionDebugInfo {
442 context,
443 request,
444 retrieval_time,
445 buffer: buffer.downgrade(),
446 position,
447 })
448 },
449 ))
450 .ok();
451 }
452
453 let (response, usage) = response?;
454 let edits = edits_from_response(&response.edits, &snapshot);
455
456 anyhow::Ok(Some((response.request_id, edits, usage)))
457 }
458 });
459
460 let buffer = buffer.clone();
461
462 cx.spawn(async move |this, cx| {
463 match request_task.await {
464 Ok(Some((id, edits, usage))) => {
465 if let Some(usage) = usage {
466 this.update(cx, |this, cx| {
467 this.user_store.update(cx, |user_store, cx| {
468 user_store.update_edit_prediction_usage(usage, cx);
469 });
470 })
471 .ok();
472 }
473
474 // TODO telemetry: duration, etc
475 let Some((edits, snapshot, edit_preview_task)) =
476 buffer.read_with(cx, |buffer, cx| {
477 let new_snapshot = buffer.snapshot();
478 let edits: Arc<[_]> =
479 interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
480 Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
481 })?
482 else {
483 return Ok(None);
484 };
485
486 Ok(Some(EditPrediction {
487 id: id.into(),
488 edits,
489 snapshot,
490 edit_preview: edit_preview_task.await,
491 }))
492 }
493 Ok(None) => Ok(None),
494 Err(err) => {
495 if err.is::<ZedUpdateRequiredError>() {
496 cx.update(|cx| {
497 this.update(cx, |this, _cx| {
498 this.update_required = true;
499 })
500 .ok();
501
502 let error_message: SharedString = err.to_string().into();
503 show_app_notification(
504 NotificationId::unique::<ZedUpdateRequiredError>(),
505 cx,
506 move |cx| {
507 cx.new(|cx| {
508 ErrorMessagePrompt::new(error_message.clone(), cx)
509 .with_link_button(
510 "Update Zed",
511 "https://zed.dev/releases",
512 )
513 })
514 },
515 );
516 })
517 .ok();
518 }
519
520 Err(err)
521 }
522 }
523 })
524 }
525
526 async fn perform_request(
527 client: Arc<Client>,
528 llm_token: LlmApiToken,
529 app_version: SemanticVersion,
530 request: predict_edits_v3::PredictEditsRequest,
531 ) -> Result<(
532 predict_edits_v3::PredictEditsResponse,
533 Option<EditPredictionUsage>,
534 )> {
535 let http_client = client.http_client();
536 let mut token = llm_token.acquire(&client).await?;
537 let mut did_retry = false;
538
539 loop {
540 let request_builder = http_client::Request::builder().method(Method::POST);
541 let request_builder =
542 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
543 request_builder.uri(predict_edits_url)
544 } else {
545 request_builder.uri(
546 http_client
547 .build_zed_llm_url("/predict_edits/v3", &[])?
548 .as_ref(),
549 )
550 };
551 let request = request_builder
552 .header("Content-Type", "application/json")
553 .header("Authorization", format!("Bearer {}", token))
554 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
555 .body(serde_json::to_string(&request)?.into())?;
556
557 let mut response = http_client.send(request).await?;
558
559 if let Some(minimum_required_version) = response
560 .headers()
561 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
562 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
563 {
564 anyhow::ensure!(
565 app_version >= minimum_required_version,
566 ZedUpdateRequiredError {
567 minimum_version: minimum_required_version
568 }
569 );
570 }
571
572 if response.status().is_success() {
573 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
574
575 let mut body = Vec::new();
576 response.body_mut().read_to_end(&mut body).await?;
577 return Ok((serde_json::from_slice(&body)?, usage));
578 } else if !did_retry
579 && response
580 .headers()
581 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
582 .is_some()
583 {
584 did_retry = true;
585 token = llm_token.refresh(&client).await?;
586 } else {
587 let mut body = String::new();
588 response.body_mut().read_to_string(&mut body).await?;
589 anyhow::bail!(
590 "error predicting edits.\nStatus: {:?}\nBody: {}",
591 response.status(),
592 body
593 );
594 }
595 }
596 }
597
598 fn gather_nearby_diagnostics(
599 cursor_offset: usize,
600 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
601 snapshot: &BufferSnapshot,
602 max_diagnostics_bytes: usize,
603 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
604 // TODO: Could make this more efficient
605 let mut diagnostic_groups = Vec::new();
606 for (language_server_id, diagnostics) in diagnostic_sets {
607 let mut groups = Vec::new();
608 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
609 diagnostic_groups.extend(
610 groups
611 .into_iter()
612 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
613 );
614 }
615
616 // sort by proximity to cursor
617 diagnostic_groups.sort_by_key(|group| {
618 let range = &group.entries[group.primary_ix].range;
619 if range.start >= cursor_offset {
620 range.start - cursor_offset
621 } else if cursor_offset >= range.end {
622 cursor_offset - range.end
623 } else {
624 (cursor_offset - range.start).min(range.end - cursor_offset)
625 }
626 });
627
628 let mut results = Vec::new();
629 let mut diagnostic_groups_truncated = false;
630 let mut diagnostics_byte_count = 0;
631 for group in diagnostic_groups {
632 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
633 diagnostics_byte_count += raw_value.get().len();
634 if diagnostics_byte_count > max_diagnostics_bytes {
635 diagnostic_groups_truncated = true;
636 break;
637 }
638 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
639 }
640
641 (results, diagnostic_groups_truncated)
642 }
643
644 // TODO: Dedupe with similar code in request_prediction?
645 pub fn cloud_request_for_zeta_cli(
646 &mut self,
647 project: &Entity<Project>,
648 buffer: &Entity<Buffer>,
649 position: language::Anchor,
650 cx: &mut Context<Self>,
651 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
652 let project_state = self.projects.get(&project.entity_id());
653
654 let index_state = project_state.map(|state| {
655 state
656 .syntax_index
657 .read_with(cx, |index, _cx| index.state().clone())
658 });
659 let options = self.options.clone();
660 let snapshot = buffer.read(cx).snapshot();
661 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
662 return Task::ready(Err(anyhow!("No file path for excerpt")));
663 };
664 let worktree_snapshots = project
665 .read(cx)
666 .worktrees(cx)
667 .map(|worktree| worktree.read(cx).snapshot())
668 .collect::<Vec<_>>();
669
670 cx.background_spawn(async move {
671 let index_state = if let Some(index_state) = index_state {
672 Some(index_state.lock_owned().await)
673 } else {
674 None
675 };
676
677 let cursor_point = position.to_point(&snapshot);
678
679 let debug_info = true;
680 EditPredictionContext::gather_context(
681 cursor_point,
682 &snapshot,
683 &options.excerpt,
684 index_state.as_deref(),
685 )
686 .context("Failed to select excerpt")
687 .map(|context| {
688 make_cloud_request(
689 excerpt_path.clone(),
690 context,
691 // TODO pass everything
692 Vec::new(),
693 false,
694 Vec::new(),
695 false,
696 None,
697 debug_info,
698 &worktree_snapshots,
699 index_state.as_deref(),
700 Some(options.max_prompt_bytes),
701 )
702 })
703 })
704 }
705}
706
/// Error returned when the server requires a newer Zed version than the one that sent the
/// request.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Version taken from the response's `MINIMUM_REQUIRED_VERSION_HEADER_NAME` header.
    minimum_version: SemanticVersion,
}
714
715fn make_cloud_request(
716 excerpt_path: PathBuf,
717 context: EditPredictionContext,
718 events: Vec<predict_edits_v3::Event>,
719 can_collect_data: bool,
720 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
721 diagnostic_groups_truncated: bool,
722 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
723 debug_info: bool,
724 worktrees: &Vec<worktree::Snapshot>,
725 index_state: Option<&SyntaxIndexState>,
726 prompt_max_bytes: Option<usize>,
727) -> predict_edits_v3::PredictEditsRequest {
728 let mut signatures = Vec::new();
729 let mut declaration_to_signature_index = HashMap::default();
730 let mut referenced_declarations = Vec::new();
731
732 for snippet in context.snippets {
733 let project_entry_id = snippet.declaration.project_entry_id();
734 let Some(path) = worktrees.iter().find_map(|worktree| {
735 worktree.entry_for_id(project_entry_id).map(|entry| {
736 let mut full_path = RelPathBuf::new();
737 full_path.push(worktree.root_name());
738 full_path.push(&entry.path);
739 full_path
740 })
741 }) else {
742 continue;
743 };
744
745 let parent_index = index_state.and_then(|index_state| {
746 snippet.declaration.parent().and_then(|parent| {
747 add_signature(
748 parent,
749 &mut declaration_to_signature_index,
750 &mut signatures,
751 index_state,
752 )
753 })
754 });
755
756 let (text, text_is_truncated) = snippet.declaration.item_text();
757 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
758 path: path.as_std_path().to_path_buf(),
759 text: text.into(),
760 range: snippet.declaration.item_range(),
761 text_is_truncated,
762 signature_range: snippet.declaration.signature_range_in_item_text(),
763 parent_index,
764 score_components: snippet.score_components,
765 signature_score: snippet.scores.signature,
766 declaration_score: snippet.scores.declaration,
767 });
768 }
769
770 let excerpt_parent = index_state.and_then(|index_state| {
771 context
772 .excerpt
773 .parent_declarations
774 .last()
775 .and_then(|(parent, _)| {
776 add_signature(
777 *parent,
778 &mut declaration_to_signature_index,
779 &mut signatures,
780 index_state,
781 )
782 })
783 });
784
785 predict_edits_v3::PredictEditsRequest {
786 excerpt_path,
787 excerpt: context.excerpt_text.body,
788 excerpt_range: context.excerpt.range,
789 cursor_offset: context.cursor_offset_in_excerpt,
790 referenced_declarations,
791 signatures,
792 excerpt_parent,
793 events,
794 can_collect_data,
795 diagnostic_groups,
796 diagnostic_groups_truncated,
797 git_info,
798 debug_info,
799 prompt_max_bytes,
800 }
801}
802
803fn add_signature(
804 declaration_id: DeclarationId,
805 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
806 signatures: &mut Vec<Signature>,
807 index: &SyntaxIndexState,
808) -> Option<usize> {
809 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
810 return Some(*signature_index);
811 }
812 let Some(parent_declaration) = index.declaration(declaration_id) else {
813 log::error!("bug: missing parent declaration");
814 return None;
815 };
816 let parent_index = parent_declaration.parent().and_then(|parent| {
817 add_signature(parent, declaration_to_signature_index, signatures, index)
818 });
819 let (text, text_is_truncated) = parent_declaration.signature_text();
820 let signature_index = signatures.len();
821 signatures.push(Signature {
822 text: text.into(),
823 text_is_truncated,
824 parent_index,
825 range: parent_declaration.signature_range(),
826 });
827 declaration_to_signature_index.insert(declaration_id, signature_index);
828 Some(signature_index)
829}