zeta2.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use chrono::TimeDelta;
  3use client::{Client, EditPredictionUsage, UserStore};
  4use cloud_llm_client::predict_edits_v3::{self, Signature};
  5use cloud_llm_client::{
  6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  7};
  8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  9use edit_prediction_context::{
 10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
 11    SyntaxIndexState,
 12};
 13use futures::AsyncReadExt as _;
 14use futures::channel::mpsc;
 15use gpui::http_client::Method;
 16use gpui::{
 17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
 18    http_client, prelude::*,
 19};
 20use language::BufferSnapshot;
 21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
 22use language_model::{LlmApiToken, RefreshLlmTokenListener};
 23use project::Project;
 24use release_channel::AppVersion;
 25use std::collections::{HashMap, VecDeque, hash_map};
 26use std::path::PathBuf;
 27use std::str::FromStr as _;
 28use std::sync::Arc;
 29use std::time::{Duration, Instant};
 30use thiserror::Error;
 31use util::some_or_debug_panic;
 32use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 33
 34mod prediction;
 35mod provider;
 36
 37use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
 38pub use provider::ZetaEditPredictionProvider;
 39
 40const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
 41
 42/// Maximum number of events to track.
 43const MAX_EVENT_COUNT: usize = 16;
 44
 45pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
 46    max_bytes: 512,
 47    min_bytes: 128,
 48    target_before_cursor_over_total_bytes: 0.5,
 49};
 50
 51pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 52    excerpt: DEFAULT_EXCERPT_OPTIONS,
 53    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
 54    max_diagnostic_bytes: 2048,
 55};
 56
 57#[derive(Clone)]
 58struct ZetaGlobal(Entity<Zeta>);
 59
 60impl Global for ZetaGlobal {}
 61
 62pub struct Zeta {
 63    client: Arc<Client>,
 64    user_store: Entity<UserStore>,
 65    llm_token: LlmApiToken,
 66    _llm_token_subscription: Subscription,
 67    projects: HashMap<EntityId, ZetaProject>,
 68    options: ZetaOptions,
 69    update_required: bool,
 70    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
 71}
 72
 73#[derive(Debug, Clone, PartialEq)]
 74pub struct ZetaOptions {
 75    pub excerpt: EditPredictionExcerptOptions,
 76    pub max_prompt_bytes: usize,
 77    pub max_diagnostic_bytes: usize,
 78}
 79
 80pub struct PredictionDebugInfo {
 81    pub context: EditPredictionContext,
 82    pub retrieval_time: TimeDelta,
 83    pub request: RequestDebugInfo,
 84    pub buffer: WeakEntity<Buffer>,
 85    pub position: language::Anchor,
 86}
 87
 88pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 89
 90struct ZetaProject {
 91    syntax_index: Entity<SyntaxIndex>,
 92    events: VecDeque<Event>,
 93    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 94}
 95
 96struct RegisteredBuffer {
 97    snapshot: BufferSnapshot,
 98    _subscriptions: [gpui::Subscription; 2],
 99}
100
101#[derive(Clone)]
102pub enum Event {
103    BufferChange {
104        old_snapshot: BufferSnapshot,
105        new_snapshot: BufferSnapshot,
106        timestamp: Instant,
107    },
108}
109
110impl Zeta {
111    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
112        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
113    }
114
115    pub fn global(
116        client: &Arc<Client>,
117        user_store: &Entity<UserStore>,
118        cx: &mut App,
119    ) -> Entity<Self> {
120        cx.try_global::<ZetaGlobal>()
121            .map(|global| global.0.clone())
122            .unwrap_or_else(|| {
123                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
124                cx.set_global(ZetaGlobal(zeta.clone()));
125                zeta
126            })
127    }
128
129    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
130        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
131
132        Self {
133            projects: HashMap::new(),
134            client,
135            user_store,
136            options: DEFAULT_OPTIONS,
137            llm_token: LlmApiToken::default(),
138            _llm_token_subscription: cx.subscribe(
139                &refresh_llm_token_listener,
140                |this, _listener, _event, cx| {
141                    let client = this.client.clone();
142                    let llm_token = this.llm_token.clone();
143                    cx.spawn(async move |_this, _cx| {
144                        llm_token.refresh(&client).await?;
145                        anyhow::Ok(())
146                    })
147                    .detach_and_log_err(cx);
148                },
149            ),
150            update_required: false,
151            debug_tx: None,
152        }
153    }
154
155    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
156        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
157        self.debug_tx = Some(debug_watch_tx);
158        debug_watch_rx
159    }
160
161    pub fn options(&self) -> &ZetaOptions {
162        &self.options
163    }
164
165    pub fn set_options(&mut self, options: ZetaOptions) {
166        self.options = options;
167    }
168
169    pub fn clear_history(&mut self) {
170        for zeta_project in self.projects.values_mut() {
171            zeta_project.events.clear();
172        }
173    }
174
175    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
176        self.user_store.read(cx).edit_prediction_usage()
177    }
178
179    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
180        self.get_or_init_zeta_project(project, cx);
181    }
182
183    pub fn register_buffer(
184        &mut self,
185        buffer: &Entity<Buffer>,
186        project: &Entity<Project>,
187        cx: &mut Context<Self>,
188    ) {
189        let zeta_project = self.get_or_init_zeta_project(project, cx);
190        Self::register_buffer_impl(zeta_project, buffer, project, cx);
191    }
192
193    fn get_or_init_zeta_project(
194        &mut self,
195        project: &Entity<Project>,
196        cx: &mut App,
197    ) -> &mut ZetaProject {
198        self.projects
199            .entry(project.entity_id())
200            .or_insert_with(|| ZetaProject {
201                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
202                events: VecDeque::new(),
203                registered_buffers: HashMap::new(),
204            })
205    }
206
207    fn register_buffer_impl<'a>(
208        zeta_project: &'a mut ZetaProject,
209        buffer: &Entity<Buffer>,
210        project: &Entity<Project>,
211        cx: &mut Context<Self>,
212    ) -> &'a mut RegisteredBuffer {
213        let buffer_id = buffer.entity_id();
214        match zeta_project.registered_buffers.entry(buffer_id) {
215            hash_map::Entry::Occupied(entry) => entry.into_mut(),
216            hash_map::Entry::Vacant(entry) => {
217                let snapshot = buffer.read(cx).snapshot();
218                let project_entity_id = project.entity_id();
219                entry.insert(RegisteredBuffer {
220                    snapshot,
221                    _subscriptions: [
222                        cx.subscribe(buffer, {
223                            let project = project.downgrade();
224                            move |this, buffer, event, cx| {
225                                if let language::BufferEvent::Edited = event
226                                    && let Some(project) = project.upgrade()
227                                {
228                                    this.report_changes_for_buffer(&buffer, &project, cx);
229                                }
230                            }
231                        }),
232                        cx.observe_release(buffer, move |this, _buffer, _cx| {
233                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
234                            else {
235                                return;
236                            };
237                            zeta_project.registered_buffers.remove(&buffer_id);
238                        }),
239                    ],
240                })
241            }
242        }
243    }
244
245    fn report_changes_for_buffer(
246        &mut self,
247        buffer: &Entity<Buffer>,
248        project: &Entity<Project>,
249        cx: &mut Context<Self>,
250    ) -> BufferSnapshot {
251        let zeta_project = self.get_or_init_zeta_project(project, cx);
252        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
253
254        let new_snapshot = buffer.read(cx).snapshot();
255        if new_snapshot.version != registered_buffer.snapshot.version {
256            let old_snapshot =
257                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
258            Self::push_event(
259                zeta_project,
260                Event::BufferChange {
261                    old_snapshot,
262                    new_snapshot: new_snapshot.clone(),
263                    timestamp: Instant::now(),
264                },
265            );
266        }
267
268        new_snapshot
269    }
270
271    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
272        let events = &mut zeta_project.events;
273
274        if let Some(Event::BufferChange {
275            new_snapshot: last_new_snapshot,
276            timestamp: last_timestamp,
277            ..
278        }) = events.back_mut()
279        {
280            // Coalesce edits for the same buffer when they happen one after the other.
281            let Event::BufferChange {
282                old_snapshot,
283                new_snapshot,
284                timestamp,
285            } = &event;
286
287            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
288                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
289                && old_snapshot.version == last_new_snapshot.version
290            {
291                *last_new_snapshot = new_snapshot.clone();
292                *last_timestamp = *timestamp;
293                return;
294            }
295        }
296
297        if events.len() >= MAX_EVENT_COUNT {
298            // These are halved instead of popping to improve prompt caching.
299            events.drain(..MAX_EVENT_COUNT / 2);
300        }
301
302        events.push_back(event);
303    }
304
305    pub fn request_prediction(
306        &mut self,
307        project: &Entity<Project>,
308        buffer: &Entity<Buffer>,
309        position: language::Anchor,
310        cx: &mut Context<Self>,
311    ) -> Task<Result<Option<EditPrediction>>> {
312        let project_state = self.projects.get(&project.entity_id());
313
314        let index_state = project_state.map(|state| {
315            state
316                .syntax_index
317                .read_with(cx, |index, _cx| index.state().clone())
318        });
319        let options = self.options.clone();
320        let snapshot = buffer.read(cx).snapshot();
321        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
322            return Task::ready(Err(anyhow!("No file path for excerpt")));
323        };
324        let client = self.client.clone();
325        let llm_token = self.llm_token.clone();
326        let app_version = AppVersion::global(cx);
327        let worktree_snapshots = project
328            .read(cx)
329            .worktrees(cx)
330            .map(|worktree| worktree.read(cx).snapshot())
331            .collect::<Vec<_>>();
332        let debug_tx = self.debug_tx.clone();
333
334        let events = project_state
335            .map(|state| {
336                state
337                    .events
338                    .iter()
339                    .map(|event| match event {
340                        Event::BufferChange {
341                            old_snapshot,
342                            new_snapshot,
343                            ..
344                        } => {
345                            let path = new_snapshot.file().map(|f| f.path().to_path_buf());
346
347                            let old_path = old_snapshot.file().and_then(|f| {
348                                let old_path = f.path().as_ref();
349                                if Some(old_path) != path.as_deref() {
350                                    Some(old_path.to_path_buf())
351                                } else {
352                                    None
353                                }
354                            });
355
356                            predict_edits_v3::Event::BufferChange {
357                                old_path,
358                                path,
359                                diff: language::unified_diff(
360                                    &old_snapshot.text(),
361                                    &new_snapshot.text(),
362                                ),
363                                //todo: Actually detect if this edit was predicted or not
364                                predicted: false,
365                            }
366                        }
367                    })
368                    .collect::<Vec<_>>()
369            })
370            .unwrap_or_default();
371
372        let diagnostics = snapshot.diagnostic_sets().clone();
373
374        let request_task = cx.background_spawn({
375            let snapshot = snapshot.clone();
376            let buffer = buffer.clone();
377            async move {
378                let index_state = if let Some(index_state) = index_state {
379                    Some(index_state.lock_owned().await)
380                } else {
381                    None
382                };
383
384                let cursor_offset = position.to_offset(&snapshot);
385                let cursor_point = cursor_offset.to_point(&snapshot);
386
387                let before_retrieval = chrono::Utc::now();
388
389                let Some(context) = EditPredictionContext::gather_context(
390                    cursor_point,
391                    &snapshot,
392                    &options.excerpt,
393                    index_state.as_deref(),
394                ) else {
395                    return Ok(None);
396                };
397
398                let debug_context = if let Some(debug_tx) = debug_tx {
399                    Some((debug_tx, context.clone()))
400                } else {
401                    None
402                };
403
404                let (diagnostic_groups, diagnostic_groups_truncated) =
405                    Self::gather_nearby_diagnostics(
406                        cursor_offset,
407                        &diagnostics,
408                        &snapshot,
409                        options.max_diagnostic_bytes,
410                    );
411
412                let request = make_cloud_request(
413                    excerpt_path.clone(),
414                    context,
415                    events,
416                    // TODO data collection
417                    false,
418                    diagnostic_groups,
419                    diagnostic_groups_truncated,
420                    None,
421                    debug_context.is_some(),
422                    &worktree_snapshots,
423                    index_state.as_deref(),
424                    Some(options.max_prompt_bytes),
425                );
426
427                let retrieval_time = chrono::Utc::now() - before_retrieval;
428                let response = Self::perform_request(client, llm_token, app_version, request).await;
429
430                if let Some((debug_tx, context)) = debug_context {
431                    debug_tx
432                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
433                            |response| {
434                                let Some(request) =
435                                    some_or_debug_panic(response.0.debug_info.clone())
436                                else {
437                                    return Err("Missing debug info".to_string());
438                                };
439                                Ok(PredictionDebugInfo {
440                                    context,
441                                    request,
442                                    retrieval_time,
443                                    buffer: buffer.downgrade(),
444                                    position,
445                                })
446                            },
447                        ))
448                        .ok();
449                }
450
451                let (response, usage) = response?;
452                let edits = edits_from_response(&response.edits, &snapshot);
453
454                anyhow::Ok(Some((response.request_id, edits, usage)))
455            }
456        });
457
458        let buffer = buffer.clone();
459
460        cx.spawn(async move |this, cx| {
461            match request_task.await {
462                Ok(Some((id, edits, usage))) => {
463                    if let Some(usage) = usage {
464                        this.update(cx, |this, cx| {
465                            this.user_store.update(cx, |user_store, cx| {
466                                user_store.update_edit_prediction_usage(usage, cx);
467                            });
468                        })
469                        .ok();
470                    }
471
472                    // TODO telemetry: duration, etc
473                    let Some((edits, snapshot, edit_preview_task)) =
474                        buffer.read_with(cx, |buffer, cx| {
475                            let new_snapshot = buffer.snapshot();
476                            let edits: Arc<[_]> =
477                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
478                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
479                        })?
480                    else {
481                        return Ok(None);
482                    };
483
484                    Ok(Some(EditPrediction {
485                        id: id.into(),
486                        edits,
487                        snapshot,
488                        edit_preview: edit_preview_task.await,
489                    }))
490                }
491                Ok(None) => Ok(None),
492                Err(err) => {
493                    if err.is::<ZedUpdateRequiredError>() {
494                        cx.update(|cx| {
495                            this.update(cx, |this, _cx| {
496                                this.update_required = true;
497                            })
498                            .ok();
499
500                            let error_message: SharedString = err.to_string().into();
501                            show_app_notification(
502                                NotificationId::unique::<ZedUpdateRequiredError>(),
503                                cx,
504                                move |cx| {
505                                    cx.new(|cx| {
506                                        ErrorMessagePrompt::new(error_message.clone(), cx)
507                                            .with_link_button(
508                                                "Update Zed",
509                                                "https://zed.dev/releases",
510                                            )
511                                    })
512                                },
513                            );
514                        })
515                        .ok();
516                    }
517
518                    Err(err)
519                }
520            }
521        })
522    }
523
524    async fn perform_request(
525        client: Arc<Client>,
526        llm_token: LlmApiToken,
527        app_version: SemanticVersion,
528        request: predict_edits_v3::PredictEditsRequest,
529    ) -> Result<(
530        predict_edits_v3::PredictEditsResponse,
531        Option<EditPredictionUsage>,
532    )> {
533        let http_client = client.http_client();
534        let mut token = llm_token.acquire(&client).await?;
535        let mut did_retry = false;
536
537        loop {
538            let request_builder = http_client::Request::builder().method(Method::POST);
539            let request_builder =
540                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
541                    request_builder.uri(predict_edits_url)
542                } else {
543                    request_builder.uri(
544                        http_client
545                            .build_zed_llm_url("/predict_edits/v3", &[])?
546                            .as_ref(),
547                    )
548                };
549            let request = request_builder
550                .header("Content-Type", "application/json")
551                .header("Authorization", format!("Bearer {}", token))
552                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
553                .body(serde_json::to_string(&request)?.into())?;
554
555            let mut response = http_client.send(request).await?;
556
557            if let Some(minimum_required_version) = response
558                .headers()
559                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
560                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
561            {
562                anyhow::ensure!(
563                    app_version >= minimum_required_version,
564                    ZedUpdateRequiredError {
565                        minimum_version: minimum_required_version
566                    }
567                );
568            }
569
570            if response.status().is_success() {
571                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
572
573                let mut body = Vec::new();
574                response.body_mut().read_to_end(&mut body).await?;
575                return Ok((serde_json::from_slice(&body)?, usage));
576            } else if !did_retry
577                && response
578                    .headers()
579                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
580                    .is_some()
581            {
582                did_retry = true;
583                token = llm_token.refresh(&client).await?;
584            } else {
585                let mut body = String::new();
586                response.body_mut().read_to_string(&mut body).await?;
587                anyhow::bail!(
588                    "error predicting edits.\nStatus: {:?}\nBody: {}",
589                    response.status(),
590                    body
591                );
592            }
593        }
594    }
595
596    fn gather_nearby_diagnostics(
597        cursor_offset: usize,
598        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
599        snapshot: &BufferSnapshot,
600        max_diagnostics_bytes: usize,
601    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
602        // TODO: Could make this more efficient
603        let mut diagnostic_groups = Vec::new();
604        for (language_server_id, diagnostics) in diagnostic_sets {
605            let mut groups = Vec::new();
606            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
607            diagnostic_groups.extend(
608                groups
609                    .into_iter()
610                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
611            );
612        }
613
614        // sort by proximity to cursor
615        diagnostic_groups.sort_by_key(|group| {
616            let range = &group.entries[group.primary_ix].range;
617            if range.start >= cursor_offset {
618                range.start - cursor_offset
619            } else if cursor_offset >= range.end {
620                cursor_offset - range.end
621            } else {
622                (cursor_offset - range.start).min(range.end - cursor_offset)
623            }
624        });
625
626        let mut results = Vec::new();
627        let mut diagnostic_groups_truncated = false;
628        let mut diagnostics_byte_count = 0;
629        for group in diagnostic_groups {
630            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
631            diagnostics_byte_count += raw_value.get().len();
632            if diagnostics_byte_count > max_diagnostics_bytes {
633                diagnostic_groups_truncated = true;
634                break;
635            }
636            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
637        }
638
639        (results, diagnostic_groups_truncated)
640    }
641
642    // TODO: Dedupe with similar code in request_prediction?
643    pub fn cloud_request_for_zeta_cli(
644        &mut self,
645        project: &Entity<Project>,
646        buffer: &Entity<Buffer>,
647        position: language::Anchor,
648        cx: &mut Context<Self>,
649    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
650        let project_state = self.projects.get(&project.entity_id());
651
652        let index_state = project_state.map(|state| {
653            state
654                .syntax_index
655                .read_with(cx, |index, _cx| index.state().clone())
656        });
657        let options = self.options.clone();
658        let snapshot = buffer.read(cx).snapshot();
659        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
660            return Task::ready(Err(anyhow!("No file path for excerpt")));
661        };
662        let worktree_snapshots = project
663            .read(cx)
664            .worktrees(cx)
665            .map(|worktree| worktree.read(cx).snapshot())
666            .collect::<Vec<_>>();
667
668        cx.background_spawn(async move {
669            let index_state = if let Some(index_state) = index_state {
670                Some(index_state.lock_owned().await)
671            } else {
672                None
673            };
674
675            let cursor_point = position.to_point(&snapshot);
676
677            let debug_info = true;
678            EditPredictionContext::gather_context(
679                cursor_point,
680                &snapshot,
681                &options.excerpt,
682                index_state.as_deref(),
683            )
684            .context("Failed to select excerpt")
685            .map(|context| {
686                make_cloud_request(
687                    excerpt_path.clone(),
688                    context,
689                    // TODO pass everything
690                    Vec::new(),
691                    false,
692                    Vec::new(),
693                    false,
694                    None,
695                    debug_info,
696                    &worktree_snapshots,
697                    index_state.as_deref(),
698                    Some(options.max_prompt_bytes),
699                )
700            })
701        })
702    }
703}
704
705#[derive(Error, Debug)]
706#[error(
707    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
708)]
709pub struct ZedUpdateRequiredError {
710    minimum_version: SemanticVersion,
711}
712
713fn make_cloud_request(
714    excerpt_path: PathBuf,
715    context: EditPredictionContext,
716    events: Vec<predict_edits_v3::Event>,
717    can_collect_data: bool,
718    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
719    diagnostic_groups_truncated: bool,
720    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
721    debug_info: bool,
722    worktrees: &Vec<worktree::Snapshot>,
723    index_state: Option<&SyntaxIndexState>,
724    prompt_max_bytes: Option<usize>,
725) -> predict_edits_v3::PredictEditsRequest {
726    let mut signatures = Vec::new();
727    let mut declaration_to_signature_index = HashMap::default();
728    let mut referenced_declarations = Vec::new();
729
730    for snippet in context.snippets {
731        let project_entry_id = snippet.declaration.project_entry_id();
732        let Some(path) = worktrees.iter().find_map(|worktree| {
733            worktree.entry_for_id(project_entry_id).map(|entry| {
734                let mut full_path = PathBuf::new();
735                full_path.push(worktree.root_name());
736                full_path.push(&entry.path);
737                full_path
738            })
739        }) else {
740            continue;
741        };
742
743        let parent_index = index_state.and_then(|index_state| {
744            snippet.declaration.parent().and_then(|parent| {
745                add_signature(
746                    parent,
747                    &mut declaration_to_signature_index,
748                    &mut signatures,
749                    index_state,
750                )
751            })
752        });
753
754        let (text, text_is_truncated) = snippet.declaration.item_text();
755        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
756            path,
757            text: text.into(),
758            range: snippet.declaration.item_range(),
759            text_is_truncated,
760            signature_range: snippet.declaration.signature_range_in_item_text(),
761            parent_index,
762            score_components: snippet.score_components,
763            signature_score: snippet.scores.signature,
764            declaration_score: snippet.scores.declaration,
765        });
766    }
767
768    let excerpt_parent = index_state.and_then(|index_state| {
769        context
770            .excerpt
771            .parent_declarations
772            .last()
773            .and_then(|(parent, _)| {
774                add_signature(
775                    *parent,
776                    &mut declaration_to_signature_index,
777                    &mut signatures,
778                    index_state,
779                )
780            })
781    });
782
783    predict_edits_v3::PredictEditsRequest {
784        excerpt_path,
785        excerpt: context.excerpt_text.body,
786        excerpt_range: context.excerpt.range,
787        cursor_offset: context.cursor_offset_in_excerpt,
788        referenced_declarations,
789        signatures,
790        excerpt_parent,
791        events,
792        can_collect_data,
793        diagnostic_groups,
794        diagnostic_groups_truncated,
795        git_info,
796        debug_info,
797        prompt_max_bytes,
798    }
799}
800
801fn add_signature(
802    declaration_id: DeclarationId,
803    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
804    signatures: &mut Vec<Signature>,
805    index: &SyntaxIndexState,
806) -> Option<usize> {
807    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
808        return Some(*signature_index);
809    }
810    let Some(parent_declaration) = index.declaration(declaration_id) else {
811        log::error!("bug: missing parent declaration");
812        return None;
813    };
814    let parent_index = parent_declaration.parent().and_then(|parent| {
815        add_signature(parent, declaration_to_signature_index, signatures, index)
816    });
817    let (text, text_is_truncated) = parent_declaration.signature_text();
818    let signature_index = signatures.len();
819    signatures.push(Signature {
820        text: text.into(),
821        text_is_truncated,
822        parent_index,
823        range: parent_declaration.signature_range(),
824    });
825    declaration_to_signature_index.insert(declaration_id, signature_index);
826    Some(signature_index)
827}