zeta2.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use chrono::TimeDelta;
  3use client::{Client, EditPredictionUsage, UserStore};
  4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
  5use cloud_llm_client::{
  6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  7};
  8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  9use edit_prediction_context::{
 10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
 11    SyntaxIndexState,
 12};
 13use futures::AsyncReadExt as _;
 14use futures::channel::mpsc;
 15use gpui::http_client::Method;
 16use gpui::{
 17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
 18    http_client, prelude::*,
 19};
 20use language::BufferSnapshot;
 21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
 22use language_model::{LlmApiToken, RefreshLlmTokenListener};
 23use project::Project;
 24use release_channel::AppVersion;
 25use std::collections::{HashMap, VecDeque, hash_map};
 26use std::path::Path;
 27use std::str::FromStr as _;
 28use std::sync::Arc;
 29use std::time::{Duration, Instant};
 30use thiserror::Error;
 31use util::rel_path::RelPathBuf;
 32use util::some_or_debug_panic;
 33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 34
 35mod prediction;
 36mod provider;
 37
 38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
 39pub use provider::ZetaEditPredictionProvider;
 40
 /// Maximum gap between two back-to-back edits of the same buffer for
/// `Zeta::push_event` to merge them into a single `Event::BufferChange`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_millis(1_000);
 42
 /// Upper bound on the number of edit events remembered per project.
///
/// When the history is full, the oldest half is discarded in one go rather than
/// one event at a time, which keeps the request prefix stable for longer and so
/// improves prompt caching.
const MAX_EVENT_COUNT: usize = 16;
 45
 46pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
 47    max_bytes: 512,
 48    min_bytes: 128,
 49    target_before_cursor_over_total_bytes: 0.5,
 50};
 51
 52pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
 53    excerpt: DEFAULT_EXCERPT_OPTIONS,
 54    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
 55    max_diagnostic_bytes: 2048,
 56    prompt_format: PromptFormat::MarkedExcerpt,
 57};
 58
 59#[derive(Clone)]
 60struct ZetaGlobal(Entity<Zeta>);
 61
 62impl Global for ZetaGlobal {}
 63
 64pub struct Zeta {
 65    client: Arc<Client>,
 66    user_store: Entity<UserStore>,
 67    llm_token: LlmApiToken,
 68    _llm_token_subscription: Subscription,
 69    projects: HashMap<EntityId, ZetaProject>,
 70    options: ZetaOptions,
 71    update_required: bool,
 72    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
 73}
 74
 75#[derive(Debug, Clone, PartialEq)]
 76pub struct ZetaOptions {
 77    pub excerpt: EditPredictionExcerptOptions,
 78    pub max_prompt_bytes: usize,
 79    pub max_diagnostic_bytes: usize,
 80    pub prompt_format: predict_edits_v3::PromptFormat,
 81}
 82
 83pub struct PredictionDebugInfo {
 84    pub context: EditPredictionContext,
 85    pub retrieval_time: TimeDelta,
 86    pub request: RequestDebugInfo,
 87    pub buffer: WeakEntity<Buffer>,
 88    pub position: language::Anchor,
 89}
 90
 91pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 92
 93struct ZetaProject {
 94    syntax_index: Entity<SyntaxIndex>,
 95    events: VecDeque<Event>,
 96    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 97}
 98
 99struct RegisteredBuffer {
100    snapshot: BufferSnapshot,
101    _subscriptions: [gpui::Subscription; 2],
102}
103
104#[derive(Clone)]
105pub enum Event {
106    BufferChange {
107        old_snapshot: BufferSnapshot,
108        new_snapshot: BufferSnapshot,
109        timestamp: Instant,
110    },
111}
112
113impl Zeta {
114    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
115        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
116    }
117
118    pub fn global(
119        client: &Arc<Client>,
120        user_store: &Entity<UserStore>,
121        cx: &mut App,
122    ) -> Entity<Self> {
123        cx.try_global::<ZetaGlobal>()
124            .map(|global| global.0.clone())
125            .unwrap_or_else(|| {
126                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
127                cx.set_global(ZetaGlobal(zeta.clone()));
128                zeta
129            })
130    }
131
132    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
133        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
134
135        Self {
136            projects: HashMap::new(),
137            client,
138            user_store,
139            options: DEFAULT_OPTIONS,
140            llm_token: LlmApiToken::default(),
141            _llm_token_subscription: cx.subscribe(
142                &refresh_llm_token_listener,
143                |this, _listener, _event, cx| {
144                    let client = this.client.clone();
145                    let llm_token = this.llm_token.clone();
146                    cx.spawn(async move |_this, _cx| {
147                        llm_token.refresh(&client).await?;
148                        anyhow::Ok(())
149                    })
150                    .detach_and_log_err(cx);
151                },
152            ),
153            update_required: false,
154            debug_tx: None,
155        }
156    }
157
158    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
159        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
160        self.debug_tx = Some(debug_watch_tx);
161        debug_watch_rx
162    }
163
164    pub fn options(&self) -> &ZetaOptions {
165        &self.options
166    }
167
168    pub fn set_options(&mut self, options: ZetaOptions) {
169        self.options = options;
170    }
171
172    pub fn clear_history(&mut self) {
173        for zeta_project in self.projects.values_mut() {
174            zeta_project.events.clear();
175        }
176    }
177
178    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
179        self.user_store.read(cx).edit_prediction_usage()
180    }
181
182    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
183        self.get_or_init_zeta_project(project, cx);
184    }
185
186    pub fn register_buffer(
187        &mut self,
188        buffer: &Entity<Buffer>,
189        project: &Entity<Project>,
190        cx: &mut Context<Self>,
191    ) {
192        let zeta_project = self.get_or_init_zeta_project(project, cx);
193        Self::register_buffer_impl(zeta_project, buffer, project, cx);
194    }
195
196    fn get_or_init_zeta_project(
197        &mut self,
198        project: &Entity<Project>,
199        cx: &mut App,
200    ) -> &mut ZetaProject {
201        self.projects
202            .entry(project.entity_id())
203            .or_insert_with(|| ZetaProject {
204                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
205                events: VecDeque::new(),
206                registered_buffers: HashMap::new(),
207            })
208    }
209
210    fn register_buffer_impl<'a>(
211        zeta_project: &'a mut ZetaProject,
212        buffer: &Entity<Buffer>,
213        project: &Entity<Project>,
214        cx: &mut Context<Self>,
215    ) -> &'a mut RegisteredBuffer {
216        let buffer_id = buffer.entity_id();
217        match zeta_project.registered_buffers.entry(buffer_id) {
218            hash_map::Entry::Occupied(entry) => entry.into_mut(),
219            hash_map::Entry::Vacant(entry) => {
220                let snapshot = buffer.read(cx).snapshot();
221                let project_entity_id = project.entity_id();
222                entry.insert(RegisteredBuffer {
223                    snapshot,
224                    _subscriptions: [
225                        cx.subscribe(buffer, {
226                            let project = project.downgrade();
227                            move |this, buffer, event, cx| {
228                                if let language::BufferEvent::Edited = event
229                                    && let Some(project) = project.upgrade()
230                                {
231                                    this.report_changes_for_buffer(&buffer, &project, cx);
232                                }
233                            }
234                        }),
235                        cx.observe_release(buffer, move |this, _buffer, _cx| {
236                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
237                            else {
238                                return;
239                            };
240                            zeta_project.registered_buffers.remove(&buffer_id);
241                        }),
242                    ],
243                })
244            }
245        }
246    }
247
248    fn report_changes_for_buffer(
249        &mut self,
250        buffer: &Entity<Buffer>,
251        project: &Entity<Project>,
252        cx: &mut Context<Self>,
253    ) -> BufferSnapshot {
254        let zeta_project = self.get_or_init_zeta_project(project, cx);
255        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
256
257        let new_snapshot = buffer.read(cx).snapshot();
258        if new_snapshot.version != registered_buffer.snapshot.version {
259            let old_snapshot =
260                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
261            Self::push_event(
262                zeta_project,
263                Event::BufferChange {
264                    old_snapshot,
265                    new_snapshot: new_snapshot.clone(),
266                    timestamp: Instant::now(),
267                },
268            );
269        }
270
271        new_snapshot
272    }
273
274    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
275        let events = &mut zeta_project.events;
276
277        if let Some(Event::BufferChange {
278            new_snapshot: last_new_snapshot,
279            timestamp: last_timestamp,
280            ..
281        }) = events.back_mut()
282        {
283            // Coalesce edits for the same buffer when they happen one after the other.
284            let Event::BufferChange {
285                old_snapshot,
286                new_snapshot,
287                timestamp,
288            } = &event;
289
290            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
291                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
292                && old_snapshot.version == last_new_snapshot.version
293            {
294                *last_new_snapshot = new_snapshot.clone();
295                *last_timestamp = *timestamp;
296                return;
297            }
298        }
299
300        if events.len() >= MAX_EVENT_COUNT {
301            // These are halved instead of popping to improve prompt caching.
302            events.drain(..MAX_EVENT_COUNT / 2);
303        }
304
305        events.push_back(event);
306    }
307
308    pub fn request_prediction(
309        &mut self,
310        project: &Entity<Project>,
311        buffer: &Entity<Buffer>,
312        position: language::Anchor,
313        cx: &mut Context<Self>,
314    ) -> Task<Result<Option<EditPrediction>>> {
315        let project_state = self.projects.get(&project.entity_id());
316
317        let index_state = project_state.map(|state| {
318            state
319                .syntax_index
320                .read_with(cx, |index, _cx| index.state().clone())
321        });
322        let options = self.options.clone();
323        let snapshot = buffer.read(cx).snapshot();
324        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
325            return Task::ready(Err(anyhow!("No file path for excerpt")));
326        };
327        let client = self.client.clone();
328        let llm_token = self.llm_token.clone();
329        let app_version = AppVersion::global(cx);
330        let worktree_snapshots = project
331            .read(cx)
332            .worktrees(cx)
333            .map(|worktree| worktree.read(cx).snapshot())
334            .collect::<Vec<_>>();
335        let debug_tx = self.debug_tx.clone();
336
337        let events = project_state
338            .map(|state| {
339                state
340                    .events
341                    .iter()
342                    .map(|event| match event {
343                        Event::BufferChange {
344                            old_snapshot,
345                            new_snapshot,
346                            ..
347                        } => {
348                            let path = new_snapshot.file().map(|f| f.path().clone());
349
350                            let old_path = old_snapshot.file().and_then(|f| {
351                                let old_path = f.path();
352                                if Some(old_path) != path.as_ref() {
353                                    Some(old_path.clone())
354                                } else {
355                                    None
356                                }
357                            });
358
359                            predict_edits_v3::Event::BufferChange {
360                                old_path: old_path
361                                    .map(|old_path| old_path.as_std_path().to_path_buf()),
362                                path: path.map(|path| path.as_std_path().to_path_buf()),
363                                diff: language::unified_diff(
364                                    &old_snapshot.text(),
365                                    &new_snapshot.text(),
366                                ),
367                                //todo: Actually detect if this edit was predicted or not
368                                predicted: false,
369                            }
370                        }
371                    })
372                    .collect::<Vec<_>>()
373            })
374            .unwrap_or_default();
375
376        let diagnostics = snapshot.diagnostic_sets().clone();
377
378        let request_task = cx.background_spawn({
379            let snapshot = snapshot.clone();
380            let buffer = buffer.clone();
381            async move {
382                let index_state = if let Some(index_state) = index_state {
383                    Some(index_state.lock_owned().await)
384                } else {
385                    None
386                };
387
388                let cursor_offset = position.to_offset(&snapshot);
389                let cursor_point = cursor_offset.to_point(&snapshot);
390
391                let before_retrieval = chrono::Utc::now();
392
393                let Some(context) = EditPredictionContext::gather_context(
394                    cursor_point,
395                    &snapshot,
396                    &options.excerpt,
397                    index_state.as_deref(),
398                ) else {
399                    return Ok(None);
400                };
401
402                let debug_context = if let Some(debug_tx) = debug_tx {
403                    Some((debug_tx, context.clone()))
404                } else {
405                    None
406                };
407
408                let (diagnostic_groups, diagnostic_groups_truncated) =
409                    Self::gather_nearby_diagnostics(
410                        cursor_offset,
411                        &diagnostics,
412                        &snapshot,
413                        options.max_diagnostic_bytes,
414                    );
415
416                let request = make_cloud_request(
417                    excerpt_path,
418                    context,
419                    events,
420                    // TODO data collection
421                    false,
422                    diagnostic_groups,
423                    diagnostic_groups_truncated,
424                    None,
425                    debug_context.is_some(),
426                    &worktree_snapshots,
427                    index_state.as_deref(),
428                    Some(options.max_prompt_bytes),
429                    options.prompt_format,
430                );
431
432                let retrieval_time = chrono::Utc::now() - before_retrieval;
433                let response = Self::perform_request(client, llm_token, app_version, request).await;
434
435                if let Some((debug_tx, context)) = debug_context {
436                    debug_tx
437                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
438                            |response| {
439                                let Some(request) =
440                                    some_or_debug_panic(response.0.debug_info.clone())
441                                else {
442                                    return Err("Missing debug info".to_string());
443                                };
444                                Ok(PredictionDebugInfo {
445                                    context,
446                                    request,
447                                    retrieval_time,
448                                    buffer: buffer.downgrade(),
449                                    position,
450                                })
451                            },
452                        ))
453                        .ok();
454                }
455
456                let (response, usage) = response?;
457                let edits = edits_from_response(&response.edits, &snapshot);
458
459                anyhow::Ok(Some((response.request_id, edits, usage)))
460            }
461        });
462
463        let buffer = buffer.clone();
464
465        cx.spawn(async move |this, cx| {
466            match request_task.await {
467                Ok(Some((id, edits, usage))) => {
468                    if let Some(usage) = usage {
469                        this.update(cx, |this, cx| {
470                            this.user_store.update(cx, |user_store, cx| {
471                                user_store.update_edit_prediction_usage(usage, cx);
472                            });
473                        })
474                        .ok();
475                    }
476
477                    // TODO telemetry: duration, etc
478                    let Some((edits, snapshot, edit_preview_task)) =
479                        buffer.read_with(cx, |buffer, cx| {
480                            let new_snapshot = buffer.snapshot();
481                            let edits: Arc<[_]> =
482                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
483                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
484                        })?
485                    else {
486                        return Ok(None);
487                    };
488
489                    Ok(Some(EditPrediction {
490                        id: id.into(),
491                        edits,
492                        snapshot,
493                        edit_preview: edit_preview_task.await,
494                    }))
495                }
496                Ok(None) => Ok(None),
497                Err(err) => {
498                    if err.is::<ZedUpdateRequiredError>() {
499                        cx.update(|cx| {
500                            this.update(cx, |this, _cx| {
501                                this.update_required = true;
502                            })
503                            .ok();
504
505                            let error_message: SharedString = err.to_string().into();
506                            show_app_notification(
507                                NotificationId::unique::<ZedUpdateRequiredError>(),
508                                cx,
509                                move |cx| {
510                                    cx.new(|cx| {
511                                        ErrorMessagePrompt::new(error_message.clone(), cx)
512                                            .with_link_button(
513                                                "Update Zed",
514                                                "https://zed.dev/releases",
515                                            )
516                                    })
517                                },
518                            );
519                        })
520                        .ok();
521                    }
522
523                    Err(err)
524                }
525            }
526        })
527    }
528
529    async fn perform_request(
530        client: Arc<Client>,
531        llm_token: LlmApiToken,
532        app_version: SemanticVersion,
533        request: predict_edits_v3::PredictEditsRequest,
534    ) -> Result<(
535        predict_edits_v3::PredictEditsResponse,
536        Option<EditPredictionUsage>,
537    )> {
538        let http_client = client.http_client();
539        let mut token = llm_token.acquire(&client).await?;
540        let mut did_retry = false;
541
542        loop {
543            let request_builder = http_client::Request::builder().method(Method::POST);
544            let request_builder =
545                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
546                    request_builder.uri(predict_edits_url)
547                } else {
548                    request_builder.uri(
549                        http_client
550                            .build_zed_llm_url("/predict_edits/v3", &[])?
551                            .as_ref(),
552                    )
553                };
554            let request = request_builder
555                .header("Content-Type", "application/json")
556                .header("Authorization", format!("Bearer {}", token))
557                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
558                .body(serde_json::to_string(&request)?.into())?;
559
560            let mut response = http_client.send(request).await?;
561
562            if let Some(minimum_required_version) = response
563                .headers()
564                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
565                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
566            {
567                anyhow::ensure!(
568                    app_version >= minimum_required_version,
569                    ZedUpdateRequiredError {
570                        minimum_version: minimum_required_version
571                    }
572                );
573            }
574
575            if response.status().is_success() {
576                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
577
578                let mut body = Vec::new();
579                response.body_mut().read_to_end(&mut body).await?;
580                return Ok((serde_json::from_slice(&body)?, usage));
581            } else if !did_retry
582                && response
583                    .headers()
584                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
585                    .is_some()
586            {
587                did_retry = true;
588                token = llm_token.refresh(&client).await?;
589            } else {
590                let mut body = String::new();
591                response.body_mut().read_to_string(&mut body).await?;
592                anyhow::bail!(
593                    "error predicting edits.\nStatus: {:?}\nBody: {}",
594                    response.status(),
595                    body
596                );
597            }
598        }
599    }
600
601    fn gather_nearby_diagnostics(
602        cursor_offset: usize,
603        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
604        snapshot: &BufferSnapshot,
605        max_diagnostics_bytes: usize,
606    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
607        // TODO: Could make this more efficient
608        let mut diagnostic_groups = Vec::new();
609        for (language_server_id, diagnostics) in diagnostic_sets {
610            let mut groups = Vec::new();
611            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
612            diagnostic_groups.extend(
613                groups
614                    .into_iter()
615                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
616            );
617        }
618
619        // sort by proximity to cursor
620        diagnostic_groups.sort_by_key(|group| {
621            let range = &group.entries[group.primary_ix].range;
622            if range.start >= cursor_offset {
623                range.start - cursor_offset
624            } else if cursor_offset >= range.end {
625                cursor_offset - range.end
626            } else {
627                (cursor_offset - range.start).min(range.end - cursor_offset)
628            }
629        });
630
631        let mut results = Vec::new();
632        let mut diagnostic_groups_truncated = false;
633        let mut diagnostics_byte_count = 0;
634        for group in diagnostic_groups {
635            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
636            diagnostics_byte_count += raw_value.get().len();
637            if diagnostics_byte_count > max_diagnostics_bytes {
638                diagnostic_groups_truncated = true;
639                break;
640            }
641            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
642        }
643
644        (results, diagnostic_groups_truncated)
645    }
646
647    // TODO: Dedupe with similar code in request_prediction?
648    pub fn cloud_request_for_zeta_cli(
649        &mut self,
650        project: &Entity<Project>,
651        buffer: &Entity<Buffer>,
652        position: language::Anchor,
653        cx: &mut Context<Self>,
654    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
655        let project_state = self.projects.get(&project.entity_id());
656
657        let index_state = project_state.map(|state| {
658            state
659                .syntax_index
660                .read_with(cx, |index, _cx| index.state().clone())
661        });
662        let options = self.options.clone();
663        let snapshot = buffer.read(cx).snapshot();
664        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
665            return Task::ready(Err(anyhow!("No file path for excerpt")));
666        };
667        let worktree_snapshots = project
668            .read(cx)
669            .worktrees(cx)
670            .map(|worktree| worktree.read(cx).snapshot())
671            .collect::<Vec<_>>();
672
673        cx.background_spawn(async move {
674            let index_state = if let Some(index_state) = index_state {
675                Some(index_state.lock_owned().await)
676            } else {
677                None
678            };
679
680            let cursor_point = position.to_point(&snapshot);
681
682            let debug_info = true;
683            EditPredictionContext::gather_context(
684                cursor_point,
685                &snapshot,
686                &options.excerpt,
687                index_state.as_deref(),
688            )
689            .context("Failed to select excerpt")
690            .map(|context| {
691                make_cloud_request(
692                    excerpt_path.into(),
693                    context,
694                    // TODO pass everything
695                    Vec::new(),
696                    false,
697                    Vec::new(),
698                    false,
699                    None,
700                    debug_info,
701                    &worktree_snapshots,
702                    index_state.as_deref(),
703                    Some(options.max_prompt_bytes),
704                    options.prompt_format,
705                )
706            })
707        })
708    }
709}
710
711#[derive(Error, Debug)]
712#[error(
713    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
714)]
715pub struct ZedUpdateRequiredError {
716    minimum_version: SemanticVersion,
717}
718
719fn make_cloud_request(
720    excerpt_path: Arc<Path>,
721    context: EditPredictionContext,
722    events: Vec<predict_edits_v3::Event>,
723    can_collect_data: bool,
724    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
725    diagnostic_groups_truncated: bool,
726    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
727    debug_info: bool,
728    worktrees: &Vec<worktree::Snapshot>,
729    index_state: Option<&SyntaxIndexState>,
730    prompt_max_bytes: Option<usize>,
731    prompt_format: PromptFormat,
732) -> predict_edits_v3::PredictEditsRequest {
733    let mut signatures = Vec::new();
734    let mut declaration_to_signature_index = HashMap::default();
735    let mut referenced_declarations = Vec::new();
736
737    for snippet in context.snippets {
738        let project_entry_id = snippet.declaration.project_entry_id();
739        let Some(path) = worktrees.iter().find_map(|worktree| {
740            worktree.entry_for_id(project_entry_id).map(|entry| {
741                let mut full_path = RelPathBuf::new();
742                full_path.push(worktree.root_name());
743                full_path.push(&entry.path);
744                full_path
745            })
746        }) else {
747            continue;
748        };
749
750        let parent_index = index_state.and_then(|index_state| {
751            snippet.declaration.parent().and_then(|parent| {
752                add_signature(
753                    parent,
754                    &mut declaration_to_signature_index,
755                    &mut signatures,
756                    index_state,
757                )
758            })
759        });
760
761        let (text, text_is_truncated) = snippet.declaration.item_text();
762        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
763            path: path.as_std_path().into(),
764            text: text.into(),
765            range: snippet.declaration.item_range(),
766            text_is_truncated,
767            signature_range: snippet.declaration.signature_range_in_item_text(),
768            parent_index,
769            score_components: snippet.score_components,
770            signature_score: snippet.scores.signature,
771            declaration_score: snippet.scores.declaration,
772        });
773    }
774
775    let excerpt_parent = index_state.and_then(|index_state| {
776        context
777            .excerpt
778            .parent_declarations
779            .last()
780            .and_then(|(parent, _)| {
781                add_signature(
782                    *parent,
783                    &mut declaration_to_signature_index,
784                    &mut signatures,
785                    index_state,
786                )
787            })
788    });
789
790    predict_edits_v3::PredictEditsRequest {
791        excerpt_path,
792        excerpt: context.excerpt_text.body,
793        excerpt_range: context.excerpt.range,
794        cursor_offset: context.cursor_offset_in_excerpt,
795        referenced_declarations,
796        signatures,
797        excerpt_parent,
798        events,
799        can_collect_data,
800        diagnostic_groups,
801        diagnostic_groups_truncated,
802        git_info,
803        debug_info,
804        prompt_max_bytes,
805        prompt_format,
806    }
807}
808
809fn add_signature(
810    declaration_id: DeclarationId,
811    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
812    signatures: &mut Vec<Signature>,
813    index: &SyntaxIndexState,
814) -> Option<usize> {
815    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
816        return Some(*signature_index);
817    }
818    let Some(parent_declaration) = index.declaration(declaration_id) else {
819        log::error!("bug: missing parent declaration");
820        return None;
821    };
822    let parent_index = parent_declaration.parent().and_then(|parent| {
823        add_signature(parent, declaration_to_signature_index, signatures, index)
824    });
825    let (text, text_is_truncated) = parent_declaration.signature_text();
826    let signature_index = signatures.len();
827    signatures.push(Signature {
828        text: text.into(),
829        text_is_truncated,
830        parent_index,
831        range: parent_declaration.signature_range(),
832    });
833    declaration_to_signature_index.insert(declaration_id, signature_index);
834    Some(signature_index)
835}