1use anyhow::{Context as _, Result, anyhow};
2use chrono::TimeDelta;
3use client::{Client, EditPredictionUsage, UserStore};
4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
5use cloud_llm_client::{
6 EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
7};
8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
9use edit_prediction_context::{
10 DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
11 SyntaxIndexState,
12};
13use futures::AsyncReadExt as _;
14use futures::channel::mpsc;
15use gpui::http_client::Method;
16use gpui::{
17 App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
18 http_client, prelude::*,
19};
20use language::BufferSnapshot;
21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
22use language_model::{LlmApiToken, RefreshLlmTokenListener};
23use project::Project;
24use release_channel::AppVersion;
25use std::collections::{HashMap, VecDeque, hash_map};
26use std::path::Path;
27use std::str::FromStr as _;
28use std::sync::Arc;
29use std::time::{Duration, Instant};
30use thiserror::Error;
31use util::rel_path::RelPathBuf;
32use util::some_or_debug_panic;
33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
34
35mod prediction;
36mod provider;
37
38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
39pub use provider::ZetaEditPredictionProvider;
40
/// How soon after the previous buffer change a new change must arrive to be merged into
/// the same [`Event::BufferChange`] (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track per project.
///
/// When this limit is reached, the oldest half of the history is discarded in one step.
const MAX_EVENT_COUNT: usize = 16;

/// Excerpt-selection options used by [`DEFAULT_OPTIONS`].
// NOTE(review): `target_before_cursor_over_total_bytes` presumably is the fraction of the
// excerpt placed before the cursor — confirm against `EditPredictionExcerptOptions`.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Options a newly created [`Zeta`] starts with; they can be replaced via
/// `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::MarkedExcerpt,
};
58
/// Holder for the process-wide [`Zeta`] entity, registered as a gpui [`Global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
63
/// Edit-prediction engine: tracks per-project edit history and syntax indices and sends
/// prediction requests to the LLM service.
pub struct Zeta {
    client: Arc<Client>,
    /// Updated with the edit-prediction usage reported alongside successful responses.
    user_store: Entity<UserStore>,
    /// Token used to authenticate prediction requests; refreshed whenever the global
    /// [`RefreshLlmTokenListener`] emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set when a request fails with [`ZedUpdateRequiredError`].
    update_required: bool,
    /// When present, debug information about each prediction request is sent here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
74
/// Tunable parameters used when building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How the excerpt around the cursor is selected.
    pub excerpt: EditPredictionExcerptOptions,
    /// Forwarded to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
82
/// Debug information emitted for a prediction request while a debug receiver obtained
/// from `Zeta::debug_info` is active.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time from the start of context gathering until the request was built.
    pub retrieval_time: TimeDelta,
    /// Debug information returned by the server in its response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested for.
    pub position: language::Anchor,
}

/// Server-side debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
92
/// State tracked for each project Zeta has seen.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first, capped at [`MAX_EVENT_COUNT`].
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A buffer whose edits are recorded in the project's event history.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer. Held but never read, presumably so gpui
    /// keeps them active for as long as the buffer stays registered.
    _subscriptions: [gpui::Subscription; 2],
}
103
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes to the
    /// same buffer may be folded into a single event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent change folded into this event was recorded.
        timestamp: Instant,
    },
}
112
113impl Zeta {
114 pub fn try_global(cx: &App) -> Option<Entity<Self>> {
115 cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
116 }
117
118 pub fn global(
119 client: &Arc<Client>,
120 user_store: &Entity<UserStore>,
121 cx: &mut App,
122 ) -> Entity<Self> {
123 cx.try_global::<ZetaGlobal>()
124 .map(|global| global.0.clone())
125 .unwrap_or_else(|| {
126 let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
127 cx.set_global(ZetaGlobal(zeta.clone()));
128 zeta
129 })
130 }
131
132 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
133 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
134
135 Self {
136 projects: HashMap::new(),
137 client,
138 user_store,
139 options: DEFAULT_OPTIONS,
140 llm_token: LlmApiToken::default(),
141 _llm_token_subscription: cx.subscribe(
142 &refresh_llm_token_listener,
143 |this, _listener, _event, cx| {
144 let client = this.client.clone();
145 let llm_token = this.llm_token.clone();
146 cx.spawn(async move |_this, _cx| {
147 llm_token.refresh(&client).await?;
148 anyhow::Ok(())
149 })
150 .detach_and_log_err(cx);
151 },
152 ),
153 update_required: false,
154 debug_tx: None,
155 }
156 }
157
158 pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
159 let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
160 self.debug_tx = Some(debug_watch_tx);
161 debug_watch_rx
162 }
163
164 pub fn options(&self) -> &ZetaOptions {
165 &self.options
166 }
167
168 pub fn set_options(&mut self, options: ZetaOptions) {
169 self.options = options;
170 }
171
172 pub fn clear_history(&mut self) {
173 for zeta_project in self.projects.values_mut() {
174 zeta_project.events.clear();
175 }
176 }
177
178 pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
179 self.user_store.read(cx).edit_prediction_usage()
180 }
181
182 pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
183 self.get_or_init_zeta_project(project, cx);
184 }
185
186 pub fn register_buffer(
187 &mut self,
188 buffer: &Entity<Buffer>,
189 project: &Entity<Project>,
190 cx: &mut Context<Self>,
191 ) {
192 let zeta_project = self.get_or_init_zeta_project(project, cx);
193 Self::register_buffer_impl(zeta_project, buffer, project, cx);
194 }
195
196 fn get_or_init_zeta_project(
197 &mut self,
198 project: &Entity<Project>,
199 cx: &mut App,
200 ) -> &mut ZetaProject {
201 self.projects
202 .entry(project.entity_id())
203 .or_insert_with(|| ZetaProject {
204 syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
205 events: VecDeque::new(),
206 registered_buffers: HashMap::new(),
207 })
208 }
209
210 fn register_buffer_impl<'a>(
211 zeta_project: &'a mut ZetaProject,
212 buffer: &Entity<Buffer>,
213 project: &Entity<Project>,
214 cx: &mut Context<Self>,
215 ) -> &'a mut RegisteredBuffer {
216 let buffer_id = buffer.entity_id();
217 match zeta_project.registered_buffers.entry(buffer_id) {
218 hash_map::Entry::Occupied(entry) => entry.into_mut(),
219 hash_map::Entry::Vacant(entry) => {
220 let snapshot = buffer.read(cx).snapshot();
221 let project_entity_id = project.entity_id();
222 entry.insert(RegisteredBuffer {
223 snapshot,
224 _subscriptions: [
225 cx.subscribe(buffer, {
226 let project = project.downgrade();
227 move |this, buffer, event, cx| {
228 if let language::BufferEvent::Edited = event
229 && let Some(project) = project.upgrade()
230 {
231 this.report_changes_for_buffer(&buffer, &project, cx);
232 }
233 }
234 }),
235 cx.observe_release(buffer, move |this, _buffer, _cx| {
236 let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
237 else {
238 return;
239 };
240 zeta_project.registered_buffers.remove(&buffer_id);
241 }),
242 ],
243 })
244 }
245 }
246 }
247
248 fn report_changes_for_buffer(
249 &mut self,
250 buffer: &Entity<Buffer>,
251 project: &Entity<Project>,
252 cx: &mut Context<Self>,
253 ) -> BufferSnapshot {
254 let zeta_project = self.get_or_init_zeta_project(project, cx);
255 let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
256
257 let new_snapshot = buffer.read(cx).snapshot();
258 if new_snapshot.version != registered_buffer.snapshot.version {
259 let old_snapshot =
260 std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
261 Self::push_event(
262 zeta_project,
263 Event::BufferChange {
264 old_snapshot,
265 new_snapshot: new_snapshot.clone(),
266 timestamp: Instant::now(),
267 },
268 );
269 }
270
271 new_snapshot
272 }
273
274 fn push_event(zeta_project: &mut ZetaProject, event: Event) {
275 let events = &mut zeta_project.events;
276
277 if let Some(Event::BufferChange {
278 new_snapshot: last_new_snapshot,
279 timestamp: last_timestamp,
280 ..
281 }) = events.back_mut()
282 {
283 // Coalesce edits for the same buffer when they happen one after the other.
284 let Event::BufferChange {
285 old_snapshot,
286 new_snapshot,
287 timestamp,
288 } = &event;
289
290 if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
291 && old_snapshot.remote_id() == last_new_snapshot.remote_id()
292 && old_snapshot.version == last_new_snapshot.version
293 {
294 *last_new_snapshot = new_snapshot.clone();
295 *last_timestamp = *timestamp;
296 return;
297 }
298 }
299
300 if events.len() >= MAX_EVENT_COUNT {
301 // These are halved instead of popping to improve prompt caching.
302 events.drain(..MAX_EVENT_COUNT / 2);
303 }
304
305 events.push_back(event);
306 }
307
/// Requests an edit prediction for `buffer` at `position` within `project`.
///
/// On the foreground it captures:
/// - the buffer snapshot and its diagnostics;
/// - the project's syntax-index state and event history, when the project is registered;
/// - worktree snapshots.
///
/// It then gathers context, builds a `predict_edits_v3` request and sends it on the
/// background executor. When the response arrives, the predicted edits are passed
/// through `interpolate_edits` against the buffer's current snapshot and an edit preview
/// is computed.
///
/// Resolves to `Ok(None)` when no excerpt could be gathered around the cursor, or when
/// `interpolate_edits` returns `None` for the current snapshot.
///
/// # Errors
///
/// Fails immediately if the buffer has no file. Request errors are propagated. A
/// [`ZedUpdateRequiredError`] additionally sets `update_required` and shows an
/// "Update Zed" notification before being returned.
pub fn request_prediction(
    &mut self,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    position: language::Anchor,
    cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
    // Without per-project state the request is sent with no syntax index and no history.
    let project_state = self.projects.get(&project.entity_id());

    let index_state = project_state.map(|state| {
        state
            .syntax_index
            .read_with(cx, |index, _cx| index.state().clone())
    });
    let options = self.options.clone();
    let snapshot = buffer.read(cx).snapshot();
    let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };
    let client = self.client.clone();
    let llm_token = self.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let worktree_snapshots = project
        .read(cx)
        .worktrees(cx)
        .map(|worktree| worktree.read(cx).snapshot())
        .collect::<Vec<_>>();
    let debug_tx = self.debug_tx.clone();

    // Convert the recorded history into wire events. Each event carries a unified diff
    // of the full buffer text between its old and new snapshots.
    let events = project_state
        .map(|state| {
            state
                .events
                .iter()
                .map(|event| match event {
                    Event::BufferChange {
                        old_snapshot,
                        new_snapshot,
                        ..
                    } => {
                        let path = new_snapshot.file().map(|f| f.path().clone());

                        // Only report `old_path` when it differs from the current path.
                        let old_path = old_snapshot.file().and_then(|f| {
                            let old_path = f.path();
                            if Some(old_path) != path.as_ref() {
                                Some(old_path.clone())
                            } else {
                                None
                            }
                        });

                        predict_edits_v3::Event::BufferChange {
                            old_path: old_path
                                .map(|old_path| old_path.as_std_path().to_path_buf()),
                            path: path.map(|path| path.as_std_path().to_path_buf()),
                            diff: language::unified_diff(
                                &old_snapshot.text(),
                                &new_snapshot.text(),
                            ),
                            // TODO: Actually detect if this edit was predicted or not
                            predicted: false,
                        }
                    }
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default();

    let diagnostics = snapshot.diagnostic_sets().clone();

    // Context gathering, diagnostic collection and the network request run on the
    // background executor.
    let request_task = cx.background_spawn({
        let snapshot = snapshot.clone();
        let buffer = buffer.clone();
        async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_offset = position.to_offset(&snapshot);
            let cursor_point = cursor_offset.to_point(&snapshot);

            let before_retrieval = chrono::Utc::now();

            let Some(context) = EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                &options.excerpt,
                index_state.as_deref(),
            ) else {
                return Ok(None);
            };

            // Keep a copy of the context for the debug channel, if one is installed.
            let debug_context = if let Some(debug_tx) = debug_tx {
                Some((debug_tx, context.clone()))
            } else {
                None
            };

            let (diagnostic_groups, diagnostic_groups_truncated) =
                Self::gather_nearby_diagnostics(
                    cursor_offset,
                    &diagnostics,
                    &snapshot,
                    options.max_diagnostic_bytes,
                );

            let request = make_cloud_request(
                excerpt_path,
                context,
                events,
                // TODO data collection
                false,
                diagnostic_groups,
                diagnostic_groups_truncated,
                None,
                debug_context.is_some(),
                &worktree_snapshots,
                index_state.as_deref(),
                Some(options.max_prompt_bytes),
                options.prompt_format,
            );

            let retrieval_time = chrono::Utc::now() - before_retrieval;
            let response = Self::perform_request(client, llm_token, app_version, request).await;

            // Report the outcome to the debug channel before propagating any error.
            // A failed send (e.g. a dropped receiver) is ignored.
            if let Some((debug_tx, context)) = debug_context {
                debug_tx
                    .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
                        |response| {
                            let Some(request) =
                                some_or_debug_panic(response.0.debug_info.clone())
                            else {
                                return Err("Missing debug info".to_string());
                            };
                            Ok(PredictionDebugInfo {
                                context,
                                request,
                                retrieval_time,
                                buffer: buffer.downgrade(),
                                position,
                            })
                        },
                    ))
                    .ok();
            }

            let (response, usage) = response?;
            let edits = edits_from_response(&response.edits, &snapshot);

            anyhow::Ok(Some((response.request_id, edits, usage)))
        }
    });

    let buffer = buffer.clone();

    cx.spawn(async move |this, cx| {
        match request_task.await {
            Ok(Some((id, edits, usage))) => {
                if let Some(usage) = usage {
                    this.update(cx, |this, cx| {
                        this.user_store.update(cx, |user_store, cx| {
                            user_store.update_edit_prediction_usage(usage, cx);
                        });
                    })
                    .ok();
                }

                // TODO telemetry: duration, etc
                // The buffer may have changed while the request was in flight, so the
                // edits are interpolated from the request snapshot onto the current one.
                let Some((edits, snapshot, edit_preview_task)) =
                    buffer.read_with(cx, |buffer, cx| {
                        let new_snapshot = buffer.snapshot();
                        let edits: Arc<[_]> =
                            interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
                        Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                    })?
                else {
                    return Ok(None);
                };

                Ok(Some(EditPrediction {
                    id: id.into(),
                    edits,
                    snapshot,
                    edit_preview: edit_preview_task.await,
                }))
            }
            Ok(None) => Ok(None),
            Err(err) => {
                // For an outdated client, set `update_required` and show an "Update Zed"
                // notification. The error is still returned to the caller.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |this, _cx| {
                            this.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button(
                                            "Update Zed",
                                            "https://zed.dev/releases",
                                        )
                                })
                            },
                        );
                    })
                    .ok();
                }

                Err(err)
            }
        }
    })
}
528
529 async fn perform_request(
530 client: Arc<Client>,
531 llm_token: LlmApiToken,
532 app_version: SemanticVersion,
533 request: predict_edits_v3::PredictEditsRequest,
534 ) -> Result<(
535 predict_edits_v3::PredictEditsResponse,
536 Option<EditPredictionUsage>,
537 )> {
538 let http_client = client.http_client();
539 let mut token = llm_token.acquire(&client).await?;
540 let mut did_retry = false;
541
542 loop {
543 let request_builder = http_client::Request::builder().method(Method::POST);
544 let request_builder =
545 if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
546 request_builder.uri(predict_edits_url)
547 } else {
548 request_builder.uri(
549 http_client
550 .build_zed_llm_url("/predict_edits/v3", &[])?
551 .as_ref(),
552 )
553 };
554 let request = request_builder
555 .header("Content-Type", "application/json")
556 .header("Authorization", format!("Bearer {}", token))
557 .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
558 .body(serde_json::to_string(&request)?.into())?;
559
560 let mut response = http_client.send(request).await?;
561
562 if let Some(minimum_required_version) = response
563 .headers()
564 .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
565 .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
566 {
567 anyhow::ensure!(
568 app_version >= minimum_required_version,
569 ZedUpdateRequiredError {
570 minimum_version: minimum_required_version
571 }
572 );
573 }
574
575 if response.status().is_success() {
576 let usage = EditPredictionUsage::from_headers(response.headers()).ok();
577
578 let mut body = Vec::new();
579 response.body_mut().read_to_end(&mut body).await?;
580 return Ok((serde_json::from_slice(&body)?, usage));
581 } else if !did_retry
582 && response
583 .headers()
584 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
585 .is_some()
586 {
587 did_retry = true;
588 token = llm_token.refresh(&client).await?;
589 } else {
590 let mut body = String::new();
591 response.body_mut().read_to_string(&mut body).await?;
592 anyhow::bail!(
593 "error predicting edits.\nStatus: {:?}\nBody: {}",
594 response.status(),
595 body
596 );
597 }
598 }
599 }
600
601 fn gather_nearby_diagnostics(
602 cursor_offset: usize,
603 diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
604 snapshot: &BufferSnapshot,
605 max_diagnostics_bytes: usize,
606 ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
607 // TODO: Could make this more efficient
608 let mut diagnostic_groups = Vec::new();
609 for (language_server_id, diagnostics) in diagnostic_sets {
610 let mut groups = Vec::new();
611 diagnostics.groups(*language_server_id, &mut groups, &snapshot);
612 diagnostic_groups.extend(
613 groups
614 .into_iter()
615 .map(|(_, group)| group.resolve::<usize>(&snapshot)),
616 );
617 }
618
619 // sort by proximity to cursor
620 diagnostic_groups.sort_by_key(|group| {
621 let range = &group.entries[group.primary_ix].range;
622 if range.start >= cursor_offset {
623 range.start - cursor_offset
624 } else if cursor_offset >= range.end {
625 cursor_offset - range.end
626 } else {
627 (cursor_offset - range.start).min(range.end - cursor_offset)
628 }
629 });
630
631 let mut results = Vec::new();
632 let mut diagnostic_groups_truncated = false;
633 let mut diagnostics_byte_count = 0;
634 for group in diagnostic_groups {
635 let raw_value = serde_json::value::to_raw_value(&group).unwrap();
636 diagnostics_byte_count += raw_value.get().len();
637 if diagnostics_byte_count > max_diagnostics_bytes {
638 diagnostic_groups_truncated = true;
639 break;
640 }
641 results.push(predict_edits_v3::DiagnosticGroup(raw_value));
642 }
643
644 (results, diagnostic_groups_truncated)
645 }
646
647 // TODO: Dedupe with similar code in request_prediction?
648 pub fn cloud_request_for_zeta_cli(
649 &mut self,
650 project: &Entity<Project>,
651 buffer: &Entity<Buffer>,
652 position: language::Anchor,
653 cx: &mut Context<Self>,
654 ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
655 let project_state = self.projects.get(&project.entity_id());
656
657 let index_state = project_state.map(|state| {
658 state
659 .syntax_index
660 .read_with(cx, |index, _cx| index.state().clone())
661 });
662 let options = self.options.clone();
663 let snapshot = buffer.read(cx).snapshot();
664 let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
665 return Task::ready(Err(anyhow!("No file path for excerpt")));
666 };
667 let worktree_snapshots = project
668 .read(cx)
669 .worktrees(cx)
670 .map(|worktree| worktree.read(cx).snapshot())
671 .collect::<Vec<_>>();
672
673 cx.background_spawn(async move {
674 let index_state = if let Some(index_state) = index_state {
675 Some(index_state.lock_owned().await)
676 } else {
677 None
678 };
679
680 let cursor_point = position.to_point(&snapshot);
681
682 let debug_info = true;
683 EditPredictionContext::gather_context(
684 cursor_point,
685 &snapshot,
686 &options.excerpt,
687 index_state.as_deref(),
688 )
689 .context("Failed to select excerpt")
690 .map(|context| {
691 make_cloud_request(
692 excerpt_path.into(),
693 context,
694 // TODO pass everything
695 Vec::new(),
696 false,
697 Vec::new(),
698 false,
699 None,
700 debug_info,
701 &worktree_snapshots,
702 index_state.as_deref(),
703 Some(options.max_prompt_bytes),
704 options.prompt_format,
705 )
706 })
707 })
708 }
709}
710
/// Returned by prediction requests when the server requires a newer Zed version than
/// the one running.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version advertised in the server's response headers.
    minimum_version: SemanticVersion,
}
718
719fn make_cloud_request(
720 excerpt_path: Arc<Path>,
721 context: EditPredictionContext,
722 events: Vec<predict_edits_v3::Event>,
723 can_collect_data: bool,
724 diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
725 diagnostic_groups_truncated: bool,
726 git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
727 debug_info: bool,
728 worktrees: &Vec<worktree::Snapshot>,
729 index_state: Option<&SyntaxIndexState>,
730 prompt_max_bytes: Option<usize>,
731 prompt_format: PromptFormat,
732) -> predict_edits_v3::PredictEditsRequest {
733 let mut signatures = Vec::new();
734 let mut declaration_to_signature_index = HashMap::default();
735 let mut referenced_declarations = Vec::new();
736
737 for snippet in context.snippets {
738 let project_entry_id = snippet.declaration.project_entry_id();
739 let Some(path) = worktrees.iter().find_map(|worktree| {
740 worktree.entry_for_id(project_entry_id).map(|entry| {
741 let mut full_path = RelPathBuf::new();
742 full_path.push(worktree.root_name());
743 full_path.push(&entry.path);
744 full_path
745 })
746 }) else {
747 continue;
748 };
749
750 let parent_index = index_state.and_then(|index_state| {
751 snippet.declaration.parent().and_then(|parent| {
752 add_signature(
753 parent,
754 &mut declaration_to_signature_index,
755 &mut signatures,
756 index_state,
757 )
758 })
759 });
760
761 let (text, text_is_truncated) = snippet.declaration.item_text();
762 referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
763 path: path.as_std_path().into(),
764 text: text.into(),
765 range: snippet.declaration.item_range(),
766 text_is_truncated,
767 signature_range: snippet.declaration.signature_range_in_item_text(),
768 parent_index,
769 score_components: snippet.score_components,
770 signature_score: snippet.scores.signature,
771 declaration_score: snippet.scores.declaration,
772 });
773 }
774
775 let excerpt_parent = index_state.and_then(|index_state| {
776 context
777 .excerpt
778 .parent_declarations
779 .last()
780 .and_then(|(parent, _)| {
781 add_signature(
782 *parent,
783 &mut declaration_to_signature_index,
784 &mut signatures,
785 index_state,
786 )
787 })
788 });
789
790 predict_edits_v3::PredictEditsRequest {
791 excerpt_path,
792 excerpt: context.excerpt_text.body,
793 excerpt_range: context.excerpt.range,
794 cursor_offset: context.cursor_offset_in_excerpt,
795 referenced_declarations,
796 signatures,
797 excerpt_parent,
798 events,
799 can_collect_data,
800 diagnostic_groups,
801 diagnostic_groups_truncated,
802 git_info,
803 debug_info,
804 prompt_max_bytes,
805 prompt_format,
806 }
807}
808
809fn add_signature(
810 declaration_id: DeclarationId,
811 declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
812 signatures: &mut Vec<Signature>,
813 index: &SyntaxIndexState,
814) -> Option<usize> {
815 if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
816 return Some(*signature_index);
817 }
818 let Some(parent_declaration) = index.declaration(declaration_id) else {
819 log::error!("bug: missing parent declaration");
820 return None;
821 };
822 let parent_index = parent_declaration.parent().and_then(|parent| {
823 add_signature(parent, declaration_to_signature_index, signatures, index)
824 });
825 let (text, text_is_truncated) = parent_declaration.signature_text();
826 let signature_index = signatures.len();
827 signatures.push(Signature {
828 text: text.into(),
829 text_is_truncated,
830 parent_index,
831 range: parent_declaration.signature_range(),
832 });
833 declaration_to_signature_index.insert(declaration_id, signature_index);
834 Some(signature_index)
835}