zeta2.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use chrono::TimeDelta;
  3use client::{Client, EditPredictionUsage, UserStore};
  4use cloud_llm_client::predict_edits_v3::{self, Signature};
  5use cloud_llm_client::{
  6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  7};
  8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  9use edit_prediction_context::{
 10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
 11    SyntaxIndexState,
 12};
 13use futures::AsyncReadExt as _;
 14use futures::channel::mpsc;
 15use gpui::http_client::Method;
 16use gpui::{
 17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
 18    http_client, prelude::*,
 19};
 20use language::BufferSnapshot;
 21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
 22use language_model::{LlmApiToken, RefreshLlmTokenListener};
 23use project::Project;
 24use release_channel::AppVersion;
 25use std::collections::{HashMap, VecDeque, hash_map};
 26use std::path::PathBuf;
 27use std::str::FromStr as _;
 28use std::sync::Arc;
 29use std::time::{Duration, Instant};
 30use thiserror::Error;
 31use util::rel_path::RelPathBuf;
 32use util::some_or_debug_panic;
 33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 34
 35mod prediction;
 36mod provider;
 37
 38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
 39pub use provider::ZetaEditPredictionProvider;
 40
/// Maximum gap between the timestamps of two consecutive buffer changes for them to be
/// coalesced into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// Applied per project: once a project's event queue holds this many entries, the oldest
/// half is dropped before the next event is appended.
const MAX_EVENT_COUNT: usize = 16;
 45
/// Excerpt options used by `DEFAULT_OPTIONS`.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
 51
/// Options a freshly constructed `Zeta` starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
};
 57
/// Wrapper that lets the shared `Zeta` entity be stored as a gpui global.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 62
/// Edit prediction state: per-project edit history plus everything needed to send
/// prediction requests.
pub struct Zeta {
    // Provides the HTTP client and is used to acquire/refresh the LLM token.
    client: Arc<Client>,
    // Read and updated with edit prediction usage.
    user_store: Entity<UserStore>,
    // Token sent as the `Authorization` bearer on prediction requests.
    llm_token: LlmApiToken,
    // Subscription to the `RefreshLlmTokenListener`; its callback refreshes `llm_token`.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set to `true` after a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    // When set (via `debug_info`), every prediction request reports debug info here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
 73
/// Tunable limits applied when building a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context` when selecting the excerpt.
    pub excerpt: EditPredictionExcerptOptions,
    /// Forwarded to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Budget, in bytes of serialized JSON, for the diagnostic groups attached to a request.
    pub max_diagnostic_bytes: usize,
}
 80
/// Debug information about a single prediction request, delivered through the channel
/// returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// The context that was gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent gathering context and diagnostics and building the request.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in its response.
    pub request: RequestDebugInfo,
    /// The buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// The position the prediction was requested at.
    pub position: language::Anchor,
}
 88
/// Alias for the debug payload carried by a `predict_edits_v3` response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 90
/// State tracked for one registered project.
struct ZetaProject {
    // Syntax index created for the project when it is first registered.
    syntax_index: Entity<SyntaxIndex>,
    // Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    // Buffers being watched for edits, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
 96
/// A buffer whose edits are being recorded.
struct RegisteredBuffer {
    // Snapshot taken the last time a change was recorded (or at registration time).
    snapshot: BufferSnapshot,
    // Keeps alive the buffer-event subscription and the release observer, in that order.
    _subscriptions: [gpui::Subscription; 2],
}
101
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. Consecutive changes may be
    /// coalesced into one event, in which case `new_snapshot` and `timestamp` describe the
    /// most recent change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
110
111impl Zeta {
112    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
113        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
114    }
115
116    pub fn global(
117        client: &Arc<Client>,
118        user_store: &Entity<UserStore>,
119        cx: &mut App,
120    ) -> Entity<Self> {
121        cx.try_global::<ZetaGlobal>()
122            .map(|global| global.0.clone())
123            .unwrap_or_else(|| {
124                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
125                cx.set_global(ZetaGlobal(zeta.clone()));
126                zeta
127            })
128    }
129
130    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
131        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
132
133        Self {
134            projects: HashMap::new(),
135            client,
136            user_store,
137            options: DEFAULT_OPTIONS,
138            llm_token: LlmApiToken::default(),
139            _llm_token_subscription: cx.subscribe(
140                &refresh_llm_token_listener,
141                |this, _listener, _event, cx| {
142                    let client = this.client.clone();
143                    let llm_token = this.llm_token.clone();
144                    cx.spawn(async move |_this, _cx| {
145                        llm_token.refresh(&client).await?;
146                        anyhow::Ok(())
147                    })
148                    .detach_and_log_err(cx);
149                },
150            ),
151            update_required: false,
152            debug_tx: None,
153        }
154    }
155
156    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
157        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
158        self.debug_tx = Some(debug_watch_tx);
159        debug_watch_rx
160    }
161
162    pub fn options(&self) -> &ZetaOptions {
163        &self.options
164    }
165
166    pub fn set_options(&mut self, options: ZetaOptions) {
167        self.options = options;
168    }
169
170    pub fn clear_history(&mut self) {
171        for zeta_project in self.projects.values_mut() {
172            zeta_project.events.clear();
173        }
174    }
175
176    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
177        self.user_store.read(cx).edit_prediction_usage()
178    }
179
180    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
181        self.get_or_init_zeta_project(project, cx);
182    }
183
184    pub fn register_buffer(
185        &mut self,
186        buffer: &Entity<Buffer>,
187        project: &Entity<Project>,
188        cx: &mut Context<Self>,
189    ) {
190        let zeta_project = self.get_or_init_zeta_project(project, cx);
191        Self::register_buffer_impl(zeta_project, buffer, project, cx);
192    }
193
194    fn get_or_init_zeta_project(
195        &mut self,
196        project: &Entity<Project>,
197        cx: &mut App,
198    ) -> &mut ZetaProject {
199        self.projects
200            .entry(project.entity_id())
201            .or_insert_with(|| ZetaProject {
202                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
203                events: VecDeque::new(),
204                registered_buffers: HashMap::new(),
205            })
206    }
207
208    fn register_buffer_impl<'a>(
209        zeta_project: &'a mut ZetaProject,
210        buffer: &Entity<Buffer>,
211        project: &Entity<Project>,
212        cx: &mut Context<Self>,
213    ) -> &'a mut RegisteredBuffer {
214        let buffer_id = buffer.entity_id();
215        match zeta_project.registered_buffers.entry(buffer_id) {
216            hash_map::Entry::Occupied(entry) => entry.into_mut(),
217            hash_map::Entry::Vacant(entry) => {
218                let snapshot = buffer.read(cx).snapshot();
219                let project_entity_id = project.entity_id();
220                entry.insert(RegisteredBuffer {
221                    snapshot,
222                    _subscriptions: [
223                        cx.subscribe(buffer, {
224                            let project = project.downgrade();
225                            move |this, buffer, event, cx| {
226                                if let language::BufferEvent::Edited = event
227                                    && let Some(project) = project.upgrade()
228                                {
229                                    this.report_changes_for_buffer(&buffer, &project, cx);
230                                }
231                            }
232                        }),
233                        cx.observe_release(buffer, move |this, _buffer, _cx| {
234                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
235                            else {
236                                return;
237                            };
238                            zeta_project.registered_buffers.remove(&buffer_id);
239                        }),
240                    ],
241                })
242            }
243        }
244    }
245
246    fn report_changes_for_buffer(
247        &mut self,
248        buffer: &Entity<Buffer>,
249        project: &Entity<Project>,
250        cx: &mut Context<Self>,
251    ) -> BufferSnapshot {
252        let zeta_project = self.get_or_init_zeta_project(project, cx);
253        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
254
255        let new_snapshot = buffer.read(cx).snapshot();
256        if new_snapshot.version != registered_buffer.snapshot.version {
257            let old_snapshot =
258                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
259            Self::push_event(
260                zeta_project,
261                Event::BufferChange {
262                    old_snapshot,
263                    new_snapshot: new_snapshot.clone(),
264                    timestamp: Instant::now(),
265                },
266            );
267        }
268
269        new_snapshot
270    }
271
272    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
273        let events = &mut zeta_project.events;
274
275        if let Some(Event::BufferChange {
276            new_snapshot: last_new_snapshot,
277            timestamp: last_timestamp,
278            ..
279        }) = events.back_mut()
280        {
281            // Coalesce edits for the same buffer when they happen one after the other.
282            let Event::BufferChange {
283                old_snapshot,
284                new_snapshot,
285                timestamp,
286            } = &event;
287
288            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
289                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
290                && old_snapshot.version == last_new_snapshot.version
291            {
292                *last_new_snapshot = new_snapshot.clone();
293                *last_timestamp = *timestamp;
294                return;
295            }
296        }
297
298        if events.len() >= MAX_EVENT_COUNT {
299            // These are halved instead of popping to improve prompt caching.
300            events.drain(..MAX_EVENT_COUNT / 2);
301        }
302
303        events.push_back(event);
304    }
305
    /// Requests an edit prediction for `buffer` at `position`.
    ///
    /// The work is split in two tasks:
    /// - a background task gathers context and nearby diagnostics, builds the cloud request,
    ///   sends it, and converts the response into edits against the snapshot taken here;
    /// - a foreground task then records usage, re-bases the edits onto the buffer's current
    ///   contents and builds an edit preview.
    ///
    /// Resolves to `Ok(None)` when no context could be gathered for the cursor position or
    /// when the edits can no longer be interpolated onto the current buffer contents.
    ///
    /// # Errors
    ///
    /// Fails immediately if the buffer has no file, and otherwise propagates any request
    /// error. When the error is a `ZedUpdateRequiredError`, `update_required` is also set
    /// and a notification with an "Update Zed" link is shown.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // `None` if the project was never registered; the request is then built without a
        // syntax index and without events.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Convert the recorded history into request events, one unified diff per event.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .map(|event| match event {
                        Event::BufferChange {
                            old_snapshot,
                            new_snapshot,
                            ..
                        } => {
                            let path = new_snapshot.file().map(|f| f.path().clone());

                            // Only report the old path when it differs from the new one.
                            let old_path = old_snapshot.file().and_then(|f| {
                                let old_path = f.path();
                                if Some(old_path) != path.as_ref() {
                                    Some(old_path.clone())
                                } else {
                                    None
                                }
                            });

                            predict_edits_v3::Event::BufferChange {
                                old_path: old_path
                                    .map(|old_path| old_path.as_std_path().to_path_buf()),
                                path: path.map(|path| path.as_std_path().to_path_buf()),
                                diff: language::unified_diff(
                                    &old_snapshot.text(),
                                    &new_snapshot.text(),
                                ),
                                //todo: Actually detect if this edit was predicted or not
                                predicted: false,
                            }
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = snapshot.diagnostic_sets().clone();

        let request_task = cx.background_spawn({
            // The background task gets its own clone; the original `snapshot` is needed
            // again below to interpolate the edits.
            let snapshot = snapshot.clone();
            let buffer = buffer.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&snapshot);
                let cursor_point = cursor_offset.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                let Some(context) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &options.excerpt,
                    index_state.as_deref(),
                ) else {
                    return Ok(None);
                };

                // Keep a copy of the context only when someone is listening for debug info.
                let debug_context = if let Some(debug_tx) = debug_tx {
                    Some((debug_tx, context.clone()))
                } else {
                    None
                };

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &snapshot,
                        options.max_diagnostic_bytes,
                    );

                let request = make_cloud_request(
                    excerpt_path.clone(),
                    context,
                    events,
                    // TODO data collection
                    false,
                    diagnostic_groups,
                    diagnostic_groups_truncated,
                    None,
                    debug_context.is_some(),
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                );

                // Covers context gathering, diagnostics gathering and request construction.
                let retrieval_time = chrono::Utc::now() - before_retrieval;
                let response = Self::perform_request(client, llm_token, app_version, request).await;

                // Report the outcome (success or stringified error) to the debug receiver
                // before the error is propagated below.
                if let Some((debug_tx, context)) = debug_context {
                    debug_tx
                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
                            |response| {
                                let Some(request) =
                                    some_or_debug_panic(response.0.debug_info.clone())
                                else {
                                    return Err("Missing debug info".to_string());
                                };
                                Ok(PredictionDebugInfo {
                                    context,
                                    request,
                                    retrieval_time,
                                    buffer: buffer.downgrade(),
                                    position,
                                })
                            },
                        ))
                        .ok();
                }

                let (response, usage) = response?;
                let edits = edits_from_response(&response.edits, &snapshot);

                anyhow::Ok(Some((response.request_id, edits, usage)))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((id, edits, usage))) => {
                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc
                    // The buffer may have changed while the request was in flight: re-base
                    // the edits from the request-time snapshot onto the current one.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: id.into(),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }
525
526    async fn perform_request(
527        client: Arc<Client>,
528        llm_token: LlmApiToken,
529        app_version: SemanticVersion,
530        request: predict_edits_v3::PredictEditsRequest,
531    ) -> Result<(
532        predict_edits_v3::PredictEditsResponse,
533        Option<EditPredictionUsage>,
534    )> {
535        let http_client = client.http_client();
536        let mut token = llm_token.acquire(&client).await?;
537        let mut did_retry = false;
538
539        loop {
540            let request_builder = http_client::Request::builder().method(Method::POST);
541            let request_builder =
542                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
543                    request_builder.uri(predict_edits_url)
544                } else {
545                    request_builder.uri(
546                        http_client
547                            .build_zed_llm_url("/predict_edits/v3", &[])?
548                            .as_ref(),
549                    )
550                };
551            let request = request_builder
552                .header("Content-Type", "application/json")
553                .header("Authorization", format!("Bearer {}", token))
554                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
555                .body(serde_json::to_string(&request)?.into())?;
556
557            let mut response = http_client.send(request).await?;
558
559            if let Some(minimum_required_version) = response
560                .headers()
561                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
562                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
563            {
564                anyhow::ensure!(
565                    app_version >= minimum_required_version,
566                    ZedUpdateRequiredError {
567                        minimum_version: minimum_required_version
568                    }
569                );
570            }
571
572            if response.status().is_success() {
573                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
574
575                let mut body = Vec::new();
576                response.body_mut().read_to_end(&mut body).await?;
577                return Ok((serde_json::from_slice(&body)?, usage));
578            } else if !did_retry
579                && response
580                    .headers()
581                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
582                    .is_some()
583            {
584                did_retry = true;
585                token = llm_token.refresh(&client).await?;
586            } else {
587                let mut body = String::new();
588                response.body_mut().read_to_string(&mut body).await?;
589                anyhow::bail!(
590                    "error predicting edits.\nStatus: {:?}\nBody: {}",
591                    response.status(),
592                    body
593                );
594            }
595        }
596    }
597
598    fn gather_nearby_diagnostics(
599        cursor_offset: usize,
600        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
601        snapshot: &BufferSnapshot,
602        max_diagnostics_bytes: usize,
603    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
604        // TODO: Could make this more efficient
605        let mut diagnostic_groups = Vec::new();
606        for (language_server_id, diagnostics) in diagnostic_sets {
607            let mut groups = Vec::new();
608            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
609            diagnostic_groups.extend(
610                groups
611                    .into_iter()
612                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
613            );
614        }
615
616        // sort by proximity to cursor
617        diagnostic_groups.sort_by_key(|group| {
618            let range = &group.entries[group.primary_ix].range;
619            if range.start >= cursor_offset {
620                range.start - cursor_offset
621            } else if cursor_offset >= range.end {
622                cursor_offset - range.end
623            } else {
624                (cursor_offset - range.start).min(range.end - cursor_offset)
625            }
626        });
627
628        let mut results = Vec::new();
629        let mut diagnostic_groups_truncated = false;
630        let mut diagnostics_byte_count = 0;
631        for group in diagnostic_groups {
632            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
633            diagnostics_byte_count += raw_value.get().len();
634            if diagnostics_byte_count > max_diagnostics_bytes {
635                diagnostic_groups_truncated = true;
636                break;
637            }
638            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
639        }
640
641        (results, diagnostic_groups_truncated)
642    }
643
644    // TODO: Dedupe with similar code in request_prediction?
645    pub fn cloud_request_for_zeta_cli(
646        &mut self,
647        project: &Entity<Project>,
648        buffer: &Entity<Buffer>,
649        position: language::Anchor,
650        cx: &mut Context<Self>,
651    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
652        let project_state = self.projects.get(&project.entity_id());
653
654        let index_state = project_state.map(|state| {
655            state
656                .syntax_index
657                .read_with(cx, |index, _cx| index.state().clone())
658        });
659        let options = self.options.clone();
660        let snapshot = buffer.read(cx).snapshot();
661        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
662            return Task::ready(Err(anyhow!("No file path for excerpt")));
663        };
664        let worktree_snapshots = project
665            .read(cx)
666            .worktrees(cx)
667            .map(|worktree| worktree.read(cx).snapshot())
668            .collect::<Vec<_>>();
669
670        cx.background_spawn(async move {
671            let index_state = if let Some(index_state) = index_state {
672                Some(index_state.lock_owned().await)
673            } else {
674                None
675            };
676
677            let cursor_point = position.to_point(&snapshot);
678
679            let debug_info = true;
680            EditPredictionContext::gather_context(
681                cursor_point,
682                &snapshot,
683                &options.excerpt,
684                index_state.as_deref(),
685            )
686            .context("Failed to select excerpt")
687            .map(|context| {
688                make_cloud_request(
689                    excerpt_path.clone(),
690                    context,
691                    // TODO pass everything
692                    Vec::new(),
693                    false,
694                    Vec::new(),
695                    false,
696                    None,
697                    debug_info,
698                    &worktree_snapshots,
699                    index_state.as_deref(),
700                    Some(options.max_prompt_bytes),
701                )
702            })
703        })
704    }
705}
706
/// Error returned when the server requires a newer app version than the one making the
/// request.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Minimum version advertised by the server in its response headers.
    minimum_version: SemanticVersion,
}
714
715fn make_cloud_request(
716    excerpt_path: PathBuf,
717    context: EditPredictionContext,
718    events: Vec<predict_edits_v3::Event>,
719    can_collect_data: bool,
720    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
721    diagnostic_groups_truncated: bool,
722    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
723    debug_info: bool,
724    worktrees: &Vec<worktree::Snapshot>,
725    index_state: Option<&SyntaxIndexState>,
726    prompt_max_bytes: Option<usize>,
727) -> predict_edits_v3::PredictEditsRequest {
728    let mut signatures = Vec::new();
729    let mut declaration_to_signature_index = HashMap::default();
730    let mut referenced_declarations = Vec::new();
731
732    for snippet in context.snippets {
733        let project_entry_id = snippet.declaration.project_entry_id();
734        let Some(path) = worktrees.iter().find_map(|worktree| {
735            worktree.entry_for_id(project_entry_id).map(|entry| {
736                let mut full_path = RelPathBuf::new();
737                full_path.push(worktree.root_name());
738                full_path.push(&entry.path);
739                full_path
740            })
741        }) else {
742            continue;
743        };
744
745        let parent_index = index_state.and_then(|index_state| {
746            snippet.declaration.parent().and_then(|parent| {
747                add_signature(
748                    parent,
749                    &mut declaration_to_signature_index,
750                    &mut signatures,
751                    index_state,
752                )
753            })
754        });
755
756        let (text, text_is_truncated) = snippet.declaration.item_text();
757        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
758            path: path.as_std_path().to_path_buf(),
759            text: text.into(),
760            range: snippet.declaration.item_range(),
761            text_is_truncated,
762            signature_range: snippet.declaration.signature_range_in_item_text(),
763            parent_index,
764            score_components: snippet.score_components,
765            signature_score: snippet.scores.signature,
766            declaration_score: snippet.scores.declaration,
767        });
768    }
769
770    let excerpt_parent = index_state.and_then(|index_state| {
771        context
772            .excerpt
773            .parent_declarations
774            .last()
775            .and_then(|(parent, _)| {
776                add_signature(
777                    *parent,
778                    &mut declaration_to_signature_index,
779                    &mut signatures,
780                    index_state,
781                )
782            })
783    });
784
785    predict_edits_v3::PredictEditsRequest {
786        excerpt_path,
787        excerpt: context.excerpt_text.body,
788        excerpt_range: context.excerpt.range,
789        cursor_offset: context.cursor_offset_in_excerpt,
790        referenced_declarations,
791        signatures,
792        excerpt_parent,
793        events,
794        can_collect_data,
795        diagnostic_groups,
796        diagnostic_groups_truncated,
797        git_info,
798        debug_info,
799        prompt_max_bytes,
800    }
801}
802
803fn add_signature(
804    declaration_id: DeclarationId,
805    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
806    signatures: &mut Vec<Signature>,
807    index: &SyntaxIndexState,
808) -> Option<usize> {
809    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
810        return Some(*signature_index);
811    }
812    let Some(parent_declaration) = index.declaration(declaration_id) else {
813        log::error!("bug: missing parent declaration");
814        return None;
815    };
816    let parent_index = parent_declaration.parent().and_then(|parent| {
817        add_signature(parent, declaration_to_signature_index, signatures, index)
818    });
819    let (text, text_is_truncated) = parent_declaration.signature_text();
820    let signature_index = signatures.len();
821    signatures.push(Signature {
822        text: text.into(),
823        text_is_truncated,
824        parent_index,
825        range: parent_declaration.signature_range(),
826    });
827    declaration_to_signature_index.insert(declaration_id, signature_index);
828    Some(signature_index)
829}