zeta2.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use chrono::TimeDelta;
  3use client::{Client, EditPredictionUsage, UserStore};
  4use cloud_llm_client::predict_edits_v3::{self, Signature};
  5use cloud_llm_client::{
  6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  7};
  8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  9use edit_prediction_context::{
 10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
 11    SyntaxIndexState,
 12};
 13use futures::AsyncReadExt as _;
 14use futures::channel::mpsc;
 15use gpui::http_client::Method;
 16use gpui::{
 17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
 18    http_client, prelude::*,
 19};
 20use language::BufferSnapshot;
 21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
 22use language_model::{LlmApiToken, RefreshLlmTokenListener};
 23use project::Project;
 24use release_channel::AppVersion;
 25use std::collections::{HashMap, VecDeque, hash_map};
 26use std::path::PathBuf;
 27use std::str::FromStr as _;
 28use std::sync::Arc;
 29use std::time::{Duration, Instant};
 30use thiserror::Error;
 31use util::some_or_debug_panic;
 32use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 33
 34mod prediction;
 35mod provider;
 36
 37use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
 38pub use provider::ZetaEditPredictionProvider;
 39
 40const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
 41
 42/// Maximum number of events to track.
 43const MAX_EVENT_COUNT: usize = 16;
 44
 45pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
 46    max_bytes: 512,
 47    min_bytes: 128,
 48    target_before_cursor_over_total_bytes: 0.5,
 49};
 50
 51pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 52    excerpt: DEFAULT_EXCERPT_OPTIONS,
 53    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
 54    max_diagnostic_bytes: 2048,
 55};
 56
 57#[derive(Clone)]
 58struct ZetaGlobal(Entity<Zeta>);
 59
 60impl Global for ZetaGlobal {}
 61
 62pub struct Zeta {
 63    client: Arc<Client>,
 64    user_store: Entity<UserStore>,
 65    llm_token: LlmApiToken,
 66    _llm_token_subscription: Subscription,
 67    projects: HashMap<EntityId, ZetaProject>,
 68    options: ZetaOptions,
 69    update_required: bool,
 70    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
 71}
 72
 73#[derive(Debug, Clone, PartialEq)]
 74pub struct ZetaOptions {
 75    pub excerpt: EditPredictionExcerptOptions,
 76    pub max_prompt_bytes: usize,
 77    pub max_diagnostic_bytes: usize,
 78}
 79
 80pub struct PredictionDebugInfo {
 81    pub context: EditPredictionContext,
 82    pub retrieval_time: TimeDelta,
 83    pub request: RequestDebugInfo,
 84    pub buffer: WeakEntity<Buffer>,
 85    pub position: language::Anchor,
 86}
 87
 88pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 89
 90struct ZetaProject {
 91    syntax_index: Entity<SyntaxIndex>,
 92    events: VecDeque<Event>,
 93    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 94}
 95
 96struct RegisteredBuffer {
 97    snapshot: BufferSnapshot,
 98    _subscriptions: [gpui::Subscription; 2],
 99}
100
101#[derive(Clone)]
102pub enum Event {
103    BufferChange {
104        old_snapshot: BufferSnapshot,
105        new_snapshot: BufferSnapshot,
106        timestamp: Instant,
107    },
108}
109
110impl Zeta {
111    pub fn global(
112        client: &Arc<Client>,
113        user_store: &Entity<UserStore>,
114        cx: &mut App,
115    ) -> Entity<Self> {
116        cx.try_global::<ZetaGlobal>()
117            .map(|global| global.0.clone())
118            .unwrap_or_else(|| {
119                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
120                cx.set_global(ZetaGlobal(zeta.clone()));
121                zeta
122            })
123    }
124
125    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
126        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
127
128        Self {
129            projects: HashMap::new(),
130            client,
131            user_store,
132            options: DEFAULT_OPTIONS,
133            llm_token: LlmApiToken::default(),
134            _llm_token_subscription: cx.subscribe(
135                &refresh_llm_token_listener,
136                |this, _listener, _event, cx| {
137                    let client = this.client.clone();
138                    let llm_token = this.llm_token.clone();
139                    cx.spawn(async move |_this, _cx| {
140                        llm_token.refresh(&client).await?;
141                        anyhow::Ok(())
142                    })
143                    .detach_and_log_err(cx);
144                },
145            ),
146            update_required: false,
147            debug_tx: None,
148        }
149    }
150
151    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
152        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
153        self.debug_tx = Some(debug_watch_tx);
154        debug_watch_rx
155    }
156
157    pub fn options(&self) -> &ZetaOptions {
158        &self.options
159    }
160
161    pub fn set_options(&mut self, options: ZetaOptions) {
162        self.options = options;
163    }
164
165    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
166        self.user_store.read(cx).edit_prediction_usage()
167    }
168
169    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
170        self.get_or_init_zeta_project(project, cx);
171    }
172
173    pub fn register_buffer(
174        &mut self,
175        buffer: &Entity<Buffer>,
176        project: &Entity<Project>,
177        cx: &mut Context<Self>,
178    ) {
179        let zeta_project = self.get_or_init_zeta_project(project, cx);
180        Self::register_buffer_impl(zeta_project, buffer, project, cx);
181    }
182
183    fn get_or_init_zeta_project(
184        &mut self,
185        project: &Entity<Project>,
186        cx: &mut App,
187    ) -> &mut ZetaProject {
188        self.projects
189            .entry(project.entity_id())
190            .or_insert_with(|| ZetaProject {
191                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
192                events: VecDeque::new(),
193                registered_buffers: HashMap::new(),
194            })
195    }
196
197    fn register_buffer_impl<'a>(
198        zeta_project: &'a mut ZetaProject,
199        buffer: &Entity<Buffer>,
200        project: &Entity<Project>,
201        cx: &mut Context<Self>,
202    ) -> &'a mut RegisteredBuffer {
203        let buffer_id = buffer.entity_id();
204        match zeta_project.registered_buffers.entry(buffer_id) {
205            hash_map::Entry::Occupied(entry) => entry.into_mut(),
206            hash_map::Entry::Vacant(entry) => {
207                let snapshot = buffer.read(cx).snapshot();
208                let project_entity_id = project.entity_id();
209                entry.insert(RegisteredBuffer {
210                    snapshot,
211                    _subscriptions: [
212                        cx.subscribe(buffer, {
213                            let project = project.downgrade();
214                            move |this, buffer, event, cx| {
215                                if let language::BufferEvent::Edited = event
216                                    && let Some(project) = project.upgrade()
217                                {
218                                    this.report_changes_for_buffer(&buffer, &project, cx);
219                                }
220                            }
221                        }),
222                        cx.observe_release(buffer, move |this, _buffer, _cx| {
223                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
224                            else {
225                                return;
226                            };
227                            zeta_project.registered_buffers.remove(&buffer_id);
228                        }),
229                    ],
230                })
231            }
232        }
233    }
234
235    fn report_changes_for_buffer(
236        &mut self,
237        buffer: &Entity<Buffer>,
238        project: &Entity<Project>,
239        cx: &mut Context<Self>,
240    ) -> BufferSnapshot {
241        let zeta_project = self.get_or_init_zeta_project(project, cx);
242        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
243
244        let new_snapshot = buffer.read(cx).snapshot();
245        if new_snapshot.version != registered_buffer.snapshot.version {
246            let old_snapshot =
247                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
248            Self::push_event(
249                zeta_project,
250                Event::BufferChange {
251                    old_snapshot,
252                    new_snapshot: new_snapshot.clone(),
253                    timestamp: Instant::now(),
254                },
255            );
256        }
257
258        new_snapshot
259    }
260
261    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
262        let events = &mut zeta_project.events;
263
264        if let Some(Event::BufferChange {
265            new_snapshot: last_new_snapshot,
266            timestamp: last_timestamp,
267            ..
268        }) = events.back_mut()
269        {
270            // Coalesce edits for the same buffer when they happen one after the other.
271            let Event::BufferChange {
272                old_snapshot,
273                new_snapshot,
274                timestamp,
275            } = &event;
276
277            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
278                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
279                && old_snapshot.version == last_new_snapshot.version
280            {
281                *last_new_snapshot = new_snapshot.clone();
282                *last_timestamp = *timestamp;
283                return;
284            }
285        }
286
287        if events.len() >= MAX_EVENT_COUNT {
288            // These are halved instead of popping to improve prompt caching.
289            events.drain(..MAX_EVENT_COUNT / 2);
290        }
291
292        events.push_back(event);
293    }
294
295    pub fn request_prediction(
296        &mut self,
297        project: &Entity<Project>,
298        buffer: &Entity<Buffer>,
299        position: language::Anchor,
300        cx: &mut Context<Self>,
301    ) -> Task<Result<Option<EditPrediction>>> {
302        let project_state = self.projects.get(&project.entity_id());
303
304        let index_state = project_state.map(|state| {
305            state
306                .syntax_index
307                .read_with(cx, |index, _cx| index.state().clone())
308        });
309        let options = self.options.clone();
310        let snapshot = buffer.read(cx).snapshot();
311        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
312            return Task::ready(Err(anyhow!("No file path for excerpt")));
313        };
314        let client = self.client.clone();
315        let llm_token = self.llm_token.clone();
316        let app_version = AppVersion::global(cx);
317        let worktree_snapshots = project
318            .read(cx)
319            .worktrees(cx)
320            .map(|worktree| worktree.read(cx).snapshot())
321            .collect::<Vec<_>>();
322        let debug_tx = self.debug_tx.clone();
323
324        let events = project_state
325            .map(|state| {
326                state
327                    .events
328                    .iter()
329                    .map(|event| match event {
330                        Event::BufferChange {
331                            old_snapshot,
332                            new_snapshot,
333                            ..
334                        } => {
335                            let path = new_snapshot.file().map(|f| f.path().to_path_buf());
336
337                            let old_path = old_snapshot.file().and_then(|f| {
338                                let old_path = f.path().as_ref();
339                                if Some(old_path) != path.as_deref() {
340                                    Some(old_path.to_path_buf())
341                                } else {
342                                    None
343                                }
344                            });
345
346                            predict_edits_v3::Event::BufferChange {
347                                old_path,
348                                path,
349                                diff: language::unified_diff(
350                                    &old_snapshot.text(),
351                                    &new_snapshot.text(),
352                                ),
353                                //todo: Actually detect if this edit was predicted or not
354                                predicted: false,
355                            }
356                        }
357                    })
358                    .collect::<Vec<_>>()
359            })
360            .unwrap_or_default();
361
362        let diagnostics = snapshot.diagnostic_sets().clone();
363
364        let request_task = cx.background_spawn({
365            let snapshot = snapshot.clone();
366            let buffer = buffer.clone();
367            async move {
368                let index_state = if let Some(index_state) = index_state {
369                    Some(index_state.lock_owned().await)
370                } else {
371                    None
372                };
373
374                let cursor_offset = position.to_offset(&snapshot);
375                let cursor_point = cursor_offset.to_point(&snapshot);
376
377                let before_retrieval = chrono::Utc::now();
378
379                let Some(context) = EditPredictionContext::gather_context(
380                    cursor_point,
381                    &snapshot,
382                    &options.excerpt,
383                    index_state.as_deref(),
384                ) else {
385                    return Ok(None);
386                };
387
388                let debug_context = if let Some(debug_tx) = debug_tx {
389                    Some((debug_tx, context.clone()))
390                } else {
391                    None
392                };
393
394                let (diagnostic_groups, diagnostic_groups_truncated) =
395                    Self::gather_nearby_diagnostics(
396                        cursor_offset,
397                        &diagnostics,
398                        &snapshot,
399                        options.max_diagnostic_bytes,
400                    );
401
402                let request = make_cloud_request(
403                    excerpt_path.clone(),
404                    context,
405                    events,
406                    // TODO data collection
407                    false,
408                    diagnostic_groups,
409                    diagnostic_groups_truncated,
410                    None,
411                    debug_context.is_some(),
412                    &worktree_snapshots,
413                    index_state.as_deref(),
414                    Some(options.max_prompt_bytes),
415                );
416
417                let retrieval_time = chrono::Utc::now() - before_retrieval;
418                let response = Self::perform_request(client, llm_token, app_version, request).await;
419
420                if let Some((debug_tx, context)) = debug_context {
421                    debug_tx
422                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
423                            |response| {
424                                let Some(request) =
425                                    some_or_debug_panic(response.0.debug_info.clone())
426                                else {
427                                    return Err("Missing debug info".to_string());
428                                };
429                                Ok(PredictionDebugInfo {
430                                    context,
431                                    request,
432                                    retrieval_time,
433                                    buffer: buffer.downgrade(),
434                                    position,
435                                })
436                            },
437                        ))
438                        .ok();
439                }
440
441                let (response, usage) = response?;
442                let edits = edits_from_response(&response.edits, &snapshot);
443
444                anyhow::Ok(Some((response.request_id, edits, usage)))
445            }
446        });
447
448        let buffer = buffer.clone();
449
450        cx.spawn(async move |this, cx| {
451            match request_task.await {
452                Ok(Some((id, edits, usage))) => {
453                    if let Some(usage) = usage {
454                        this.update(cx, |this, cx| {
455                            this.user_store.update(cx, |user_store, cx| {
456                                user_store.update_edit_prediction_usage(usage, cx);
457                            });
458                        })
459                        .ok();
460                    }
461
462                    // TODO telemetry: duration, etc
463                    let Some((edits, snapshot, edit_preview_task)) =
464                        buffer.read_with(cx, |buffer, cx| {
465                            let new_snapshot = buffer.snapshot();
466                            let edits: Arc<[_]> =
467                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
468                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
469                        })?
470                    else {
471                        return Ok(None);
472                    };
473
474                    Ok(Some(EditPrediction {
475                        id: id.into(),
476                        edits,
477                        snapshot,
478                        edit_preview: edit_preview_task.await,
479                    }))
480                }
481                Ok(None) => Ok(None),
482                Err(err) => {
483                    if err.is::<ZedUpdateRequiredError>() {
484                        cx.update(|cx| {
485                            this.update(cx, |this, _cx| {
486                                this.update_required = true;
487                            })
488                            .ok();
489
490                            let error_message: SharedString = err.to_string().into();
491                            show_app_notification(
492                                NotificationId::unique::<ZedUpdateRequiredError>(),
493                                cx,
494                                move |cx| {
495                                    cx.new(|cx| {
496                                        ErrorMessagePrompt::new(error_message.clone(), cx)
497                                            .with_link_button(
498                                                "Update Zed",
499                                                "https://zed.dev/releases",
500                                            )
501                                    })
502                                },
503                            );
504                        })
505                        .ok();
506                    }
507
508                    Err(err)
509                }
510            }
511        })
512    }
513
514    async fn perform_request(
515        client: Arc<Client>,
516        llm_token: LlmApiToken,
517        app_version: SemanticVersion,
518        request: predict_edits_v3::PredictEditsRequest,
519    ) -> Result<(
520        predict_edits_v3::PredictEditsResponse,
521        Option<EditPredictionUsage>,
522    )> {
523        let http_client = client.http_client();
524        let mut token = llm_token.acquire(&client).await?;
525        let mut did_retry = false;
526
527        loop {
528            let request_builder = http_client::Request::builder().method(Method::POST);
529            let request_builder =
530                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
531                    request_builder.uri(predict_edits_url)
532                } else {
533                    request_builder.uri(
534                        http_client
535                            .build_zed_llm_url("/predict_edits/v3", &[])?
536                            .as_ref(),
537                    )
538                };
539            let request = request_builder
540                .header("Content-Type", "application/json")
541                .header("Authorization", format!("Bearer {}", token))
542                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
543                .body(serde_json::to_string(&request)?.into())?;
544
545            let mut response = http_client.send(request).await?;
546
547            if let Some(minimum_required_version) = response
548                .headers()
549                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
550                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
551            {
552                anyhow::ensure!(
553                    app_version >= minimum_required_version,
554                    ZedUpdateRequiredError {
555                        minimum_version: minimum_required_version
556                    }
557                );
558            }
559
560            if response.status().is_success() {
561                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
562
563                let mut body = Vec::new();
564                response.body_mut().read_to_end(&mut body).await?;
565                return Ok((serde_json::from_slice(&body)?, usage));
566            } else if !did_retry
567                && response
568                    .headers()
569                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
570                    .is_some()
571            {
572                did_retry = true;
573                token = llm_token.refresh(&client).await?;
574            } else {
575                let mut body = String::new();
576                response.body_mut().read_to_string(&mut body).await?;
577                anyhow::bail!(
578                    "error predicting edits.\nStatus: {:?}\nBody: {}",
579                    response.status(),
580                    body
581                );
582            }
583        }
584    }
585
586    fn gather_nearby_diagnostics(
587        cursor_offset: usize,
588        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
589        snapshot: &BufferSnapshot,
590        max_diagnostics_bytes: usize,
591    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
592        // TODO: Could make this more efficient
593        let mut diagnostic_groups = Vec::new();
594        for (language_server_id, diagnostics) in diagnostic_sets {
595            let mut groups = Vec::new();
596            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
597            diagnostic_groups.extend(
598                groups
599                    .into_iter()
600                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
601            );
602        }
603
604        // sort by proximity to cursor
605        diagnostic_groups.sort_by_key(|group| {
606            let range = &group.entries[group.primary_ix].range;
607            if range.start >= cursor_offset {
608                range.start - cursor_offset
609            } else if cursor_offset >= range.end {
610                cursor_offset - range.end
611            } else {
612                (cursor_offset - range.start).min(range.end - cursor_offset)
613            }
614        });
615
616        let mut results = Vec::new();
617        let mut diagnostic_groups_truncated = false;
618        let mut diagnostics_byte_count = 0;
619        for group in diagnostic_groups {
620            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
621            diagnostics_byte_count += raw_value.get().len();
622            if diagnostics_byte_count > max_diagnostics_bytes {
623                diagnostic_groups_truncated = true;
624                break;
625            }
626            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
627        }
628
629        (results, diagnostic_groups_truncated)
630    }
631
632    // TODO: Dedupe with similar code in request_prediction?
633    pub fn cloud_request_for_zeta_cli(
634        &mut self,
635        project: &Entity<Project>,
636        buffer: &Entity<Buffer>,
637        position: language::Anchor,
638        cx: &mut Context<Self>,
639    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
640        let project_state = self.projects.get(&project.entity_id());
641
642        let index_state = project_state.map(|state| {
643            state
644                .syntax_index
645                .read_with(cx, |index, _cx| index.state().clone())
646        });
647        let options = self.options.clone();
648        let snapshot = buffer.read(cx).snapshot();
649        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
650            return Task::ready(Err(anyhow!("No file path for excerpt")));
651        };
652        let worktree_snapshots = project
653            .read(cx)
654            .worktrees(cx)
655            .map(|worktree| worktree.read(cx).snapshot())
656            .collect::<Vec<_>>();
657
658        cx.background_spawn(async move {
659            let index_state = if let Some(index_state) = index_state {
660                Some(index_state.lock_owned().await)
661            } else {
662                None
663            };
664
665            let cursor_point = position.to_point(&snapshot);
666
667            let debug_info = true;
668            EditPredictionContext::gather_context(
669                cursor_point,
670                &snapshot,
671                &options.excerpt,
672                index_state.as_deref(),
673            )
674            .context("Failed to select excerpt")
675            .map(|context| {
676                make_cloud_request(
677                    excerpt_path.clone(),
678                    context,
679                    // TODO pass everything
680                    Vec::new(),
681                    false,
682                    Vec::new(),
683                    false,
684                    None,
685                    debug_info,
686                    &worktree_snapshots,
687                    index_state.as_deref(),
688                    Some(options.max_prompt_bytes),
689                )
690            })
691        })
692    }
693}
694
695#[derive(Error, Debug)]
696#[error(
697    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
698)]
699pub struct ZedUpdateRequiredError {
700    minimum_version: SemanticVersion,
701}
702
703fn make_cloud_request(
704    excerpt_path: PathBuf,
705    context: EditPredictionContext,
706    events: Vec<predict_edits_v3::Event>,
707    can_collect_data: bool,
708    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
709    diagnostic_groups_truncated: bool,
710    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
711    debug_info: bool,
712    worktrees: &Vec<worktree::Snapshot>,
713    index_state: Option<&SyntaxIndexState>,
714    prompt_max_bytes: Option<usize>,
715) -> predict_edits_v3::PredictEditsRequest {
716    let mut signatures = Vec::new();
717    let mut declaration_to_signature_index = HashMap::default();
718    let mut referenced_declarations = Vec::new();
719
720    for snippet in context.snippets {
721        let project_entry_id = snippet.declaration.project_entry_id();
722        let Some(path) = worktrees.iter().find_map(|worktree| {
723            worktree.entry_for_id(project_entry_id).map(|entry| {
724                let mut full_path = PathBuf::new();
725                full_path.push(worktree.root_name());
726                full_path.push(&entry.path);
727                full_path
728            })
729        }) else {
730            continue;
731        };
732
733        let parent_index = index_state.and_then(|index_state| {
734            snippet.declaration.parent().and_then(|parent| {
735                add_signature(
736                    parent,
737                    &mut declaration_to_signature_index,
738                    &mut signatures,
739                    index_state,
740                )
741            })
742        });
743
744        let (text, text_is_truncated) = snippet.declaration.item_text();
745        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
746            path,
747            text: text.into(),
748            range: snippet.declaration.item_range(),
749            text_is_truncated,
750            signature_range: snippet.declaration.signature_range_in_item_text(),
751            parent_index,
752            score_components: snippet.score_components,
753            signature_score: snippet.scores.signature,
754            declaration_score: snippet.scores.declaration,
755        });
756    }
757
758    let excerpt_parent = index_state.and_then(|index_state| {
759        context
760            .excerpt
761            .parent_declarations
762            .last()
763            .and_then(|(parent, _)| {
764                add_signature(
765                    *parent,
766                    &mut declaration_to_signature_index,
767                    &mut signatures,
768                    index_state,
769                )
770            })
771    });
772
773    predict_edits_v3::PredictEditsRequest {
774        excerpt_path,
775        excerpt: context.excerpt_text.body,
776        excerpt_range: context.excerpt.range,
777        cursor_offset: context.cursor_offset_in_excerpt,
778        referenced_declarations,
779        signatures,
780        excerpt_parent,
781        events,
782        can_collect_data,
783        diagnostic_groups,
784        diagnostic_groups_truncated,
785        git_info,
786        debug_info,
787        prompt_max_bytes,
788    }
789}
790
791fn add_signature(
792    declaration_id: DeclarationId,
793    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
794    signatures: &mut Vec<Signature>,
795    index: &SyntaxIndexState,
796) -> Option<usize> {
797    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
798        return Some(*signature_index);
799    }
800    let Some(parent_declaration) = index.declaration(declaration_id) else {
801        log::error!("bug: missing parent declaration");
802        return None;
803    };
804    let parent_index = parent_declaration.parent().and_then(|parent| {
805        add_signature(parent, declaration_to_signature_index, signatures, index)
806    });
807    let (text, text_is_truncated) = parent_declaration.signature_text();
808    let signature_index = signatures.len();
809    signatures.push(Signature {
810        text: text.into(),
811        text_is_truncated,
812        parent_index,
813        range: parent_declaration.signature_range(),
814    });
815    declaration_to_signature_index.insert(declaration_id, signature_index);
816    Some(signature_index)
817}