zeta2.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use arrayvec::ArrayVec;
  3use client::{Client, EditPredictionUsage, UserStore};
  4use cloud_llm_client::predict_edits_v3::{self, Signature};
  5use cloud_llm_client::{
  6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
  7};
  8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  9use edit_prediction_context::{
 10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
 11    SyntaxIndexState,
 12};
 13use futures::AsyncReadExt as _;
 14use gpui::http_client::Method;
 15use gpui::{
 16    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
 17    prelude::*,
 18};
 19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
 20use language::{BufferSnapshot, EditPreview};
 21use language_model::{LlmApiToken, RefreshLlmTokenListener};
 22use project::Project;
 23use release_channel::AppVersion;
 24use std::cmp;
 25use std::collections::HashMap;
 26use std::path::PathBuf;
 27use std::str::FromStr as _;
 28use std::time::{Duration, Instant};
 29use std::{ops::Range, sync::Arc};
 30use thiserror::Error;
 31use util::ResultExt as _;
 32use uuid::Uuid;
 33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 34
/// Newtype wrapping the shared [`Zeta`] entity so that it can be stored as an
/// app-wide global (see `Zeta::global`, which reads and sets it).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

// Marker impl that lets `ZetaGlobal` be used with `cx.try_global` / `cx.set_global`.
impl Global for ZetaGlobal {}
 39
/// Shared state used to issue edit prediction requests.
///
/// A single instance is created lazily and shared through `Zeta::global`.
pub struct Zeta {
    /// Provides the HTTP client and is used to acquire/refresh the LLM token.
    client: Arc<Client>,
    /// Read for edit prediction usage and updated with usage returned by requests.
    user_store: Entity<UserStore>,
    /// Token sent in the `Authorization` header of prediction requests.
    llm_token: LlmApiToken,
    /// Subscription to the `RefreshLlmTokenListener`; its callback refreshes `llm_token`.
    /// Held in a field, presumably so it stays active as long as `Zeta` lives.
    _llm_token_subscription: Subscription,
    /// Projects registered through `register_project`, keyed by project entity id.
    projects: HashMap<EntityId, RegisteredProject>,
    /// Options passed to `EditPredictionContext::gather_context` when building a request.
    excerpt_options: EditPredictionExcerptOptions,
    /// Set to `true` once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
}
 49
/// Per-project state kept by [`Zeta`] for each registered project.
struct RegisteredProject {
    /// Syntax index created for the project in `Zeta::register_project`; its state is
    /// read when gathering context for a prediction request.
    syntax_index: Entity<SyntaxIndex>,
}
 53
 54impl Zeta {
 55    pub fn global(
 56        client: &Arc<Client>,
 57        user_store: &Entity<UserStore>,
 58        cx: &mut App,
 59    ) -> Entity<Self> {
 60        cx.try_global::<ZetaGlobal>()
 61            .map(|global| global.0.clone())
 62            .unwrap_or_else(|| {
 63                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 64                cx.set_global(ZetaGlobal(zeta.clone()));
 65                zeta
 66            })
 67    }
 68
 69    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 70        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 71
 72        Self {
 73            projects: HashMap::new(),
 74            client,
 75            user_store,
 76            excerpt_options: EditPredictionExcerptOptions {
 77                max_bytes: 512,
 78                min_bytes: 128,
 79                target_before_cursor_over_total_bytes: 0.5,
 80            },
 81            llm_token: LlmApiToken::default(),
 82            _llm_token_subscription: cx.subscribe(
 83                &refresh_llm_token_listener,
 84                |this, _listener, _event, cx| {
 85                    let client = this.client.clone();
 86                    let llm_token = this.llm_token.clone();
 87                    cx.spawn(async move |_this, _cx| {
 88                        llm_token.refresh(&client).await?;
 89                        anyhow::Ok(())
 90                    })
 91                    .detach_and_log_err(cx);
 92                },
 93            ),
 94            update_required: false,
 95        }
 96    }
 97
 98    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 99        self.user_store.read(cx).edit_prediction_usage()
100    }
101
102    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
103        self.projects
104            .entry(project.entity_id())
105            .or_insert_with(|| RegisteredProject {
106                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
107            });
108    }
109
    /// Requests an edit prediction for `position` in `buffer`.
    ///
    /// Context is gathered and the cloud request is performed on the background
    /// executor; the response is then converted into an [`EditPrediction`] whose edits
    /// are interpolated against the buffer's snapshot at the time the response arrives.
    ///
    /// The returned task resolves to:
    /// - `Ok(Some(prediction))` on success,
    /// - `Ok(None)` when no context could be gathered, or when the response edits can
    ///   no longer be interpolated onto the current buffer contents,
    /// - `Err(_)` when the buffer has no file, or when the request fails. If the
    ///   failure is a `ZedUpdateRequiredError`, `update_required` is set and an app
    ///   notification linking to the releases page is shown before the error is returned.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // The project may not have been registered; in that case no index state is used.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let excerpt_options = self.excerpt_options.clone();
        // Snapshot taken at request time; response edit ranges are anchored against it.
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            // Clone so the original `snapshot` remains available to the task spawned below.
            let snapshot = snapshot.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_point = position.to_point(&snapshot);

                // TODO: make this only true if debug view is open
                let debug_info = true;

                // `gather_context` yields `None` when no context is available; in that
                // case there is nothing to request.
                let Some(request) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &excerpt_options,
                    index_state.as_deref(),
                )
                .map(|context| {
                    make_cloud_request(
                        excerpt_path.clone(),
                        context,
                        // TODO pass everything
                        Vec::new(),
                        false,
                        Vec::new(),
                        None,
                        debug_info,
                        &worktree_snapshots,
                        index_state.as_deref(),
                    )
                }) else {
                    return Ok(None);
                };

                anyhow::Ok(Some(
                    Self::perform_request(client, llm_token, app_version, request).await?,
                ))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    log::debug!("predicted edits: {:?}", &response.edits);

                    // Record usage reported by the response; ignored if `this` is gone.
                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc

                    // TODO produce smaller edits by diffing against snapshot first
                    //
                    // Cloud returns entire snippets/excerpts ranges as they were included
                    // in the request, but we should display smaller edits to the user.
                    //
                    // We can do this by computing a diff of each one against the snapshot.
                    // Similar to zeta::Zeta::compute_edits, but per edit.
                    let edits = response
                        .edits
                        .into_iter()
                        .map(|edit| {
                            // TODO edits to different files
                            (
                                snapshot.anchor_before(edit.range.start)
                                    ..snapshot.anchor_before(edit.range.end),
                                edit.content,
                            )
                        })
                        .collect::<Vec<_>>()
                        .into();

                    // Re-base the edits onto the buffer's current snapshot; if that is not
                    // possible (`interpolate` returns `None`) the prediction is dropped.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: EditPredictionId(response.request_id),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    // A too-old client gets a dedicated notification in addition to the error.
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }
270
271    async fn perform_request(
272        client: Arc<Client>,
273        llm_token: LlmApiToken,
274        app_version: SemanticVersion,
275        request: predict_edits_v3::PredictEditsRequest,
276    ) -> Result<(
277        predict_edits_v3::PredictEditsResponse,
278        Option<EditPredictionUsage>,
279    )> {
280        let http_client = client.http_client();
281        let mut token = llm_token.acquire(&client).await?;
282        let mut did_retry = false;
283
284        loop {
285            let request_builder = http_client::Request::builder().method(Method::POST);
286            let request_builder =
287                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
288                    request_builder.uri(predict_edits_url)
289                } else {
290                    request_builder.uri(
291                        http_client
292                            .build_zed_llm_url("/predict_edits/v3", &[])?
293                            .as_ref(),
294                    )
295                };
296            let request = request_builder
297                .header("Content-Type", "application/json")
298                .header("Authorization", format!("Bearer {}", token))
299                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
300                .body(serde_json::to_string(&request)?.into())?;
301
302            let mut response = http_client.send(request).await?;
303
304            if let Some(minimum_required_version) = response
305                .headers()
306                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
307                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
308            {
309                anyhow::ensure!(
310                    app_version >= minimum_required_version,
311                    ZedUpdateRequiredError {
312                        minimum_version: minimum_required_version
313                    }
314                );
315            }
316
317            if response.status().is_success() {
318                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
319
320                let mut body = Vec::new();
321                response.body_mut().read_to_end(&mut body).await?;
322                return Ok((serde_json::from_slice(&body)?, usage));
323            } else if !did_retry
324                && response
325                    .headers()
326                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
327                    .is_some()
328            {
329                did_retry = true;
330                token = llm_token.refresh(&client).await?;
331            } else {
332                let mut body = String::new();
333                response.body_mut().read_to_string(&mut body).await?;
334                anyhow::bail!(
335                    "error predicting edits.\nStatus: {:?}\nBody: {}",
336                    response.status(),
337                    body
338                );
339            }
340        }
341    }
342}
343
/// Error produced by `Zeta::perform_request` when the response advertises a minimum
/// required version that is greater than the running app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version parsed from the response header.
    minimum_version: SemanticVersion,
}
351
/// `EditPredictionProvider` implementation backed by the shared [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    /// Shared entity used to issue prediction requests and read usage.
    zeta: Entity<Zeta>,
    /// Most recently accepted-as-current prediction, offered by `suggest`.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next pending prediction; incremented on every `refresh`.
    next_pending_prediction_id: usize,
    /// In-flight prediction tasks; `refresh` keeps at most two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// Time the last request was issued; used together with `THROTTLE_TIMEOUT`.
    last_request_timestamp: Instant,
}
359
360impl ZetaEditPredictionProvider {
361    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
362
363    pub fn new(
364        project: Option<&Entity<Project>>,
365        client: &Arc<Client>,
366        user_store: &Entity<UserStore>,
367        cx: &mut App,
368    ) -> Self {
369        let zeta = Zeta::global(client, user_store, cx);
370        if let Some(project) = project {
371            zeta.update(cx, |zeta, cx| {
372                zeta.register_project(project, cx);
373            });
374        }
375
376        Self {
377            zeta,
378            current_prediction: None,
379            next_pending_prediction_id: 0,
380            pending_predictions: ArrayVec::new(),
381            last_request_timestamp: Instant::now(),
382        }
383    }
384}
385
/// A prediction together with the id of the buffer it was requested for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was generated for.
    buffer_id: EntityId,
    /// The prediction itself.
    prediction: EditPrediction,
}
391
392impl CurrentEditPrediction {
393    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
394        if self.buffer_id != old_prediction.buffer_id {
395            return true;
396        }
397
398        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
399            return true;
400        };
401        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
402            return false;
403        };
404
405        if old_edits.len() == 1 && new_edits.len() == 1 {
406            let (old_range, old_text) = &old_edits[0];
407            let (new_range, new_text) = &new_edits[0];
408            new_range == old_range && new_text.starts_with(old_text)
409        } else {
410            true
411        }
412    }
413}
414
/// Identifier of an edit prediction; `Zeta::request_prediction` fills it with the
/// `request_id` of the response.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
417
418impl From<EditPredictionId> for gpui::ElementId {
419    fn from(value: EditPredictionId) -> Self {
420        gpui::ElementId::Uuid(value.0)
421    }
422}
423
424impl std::fmt::Display for EditPredictionId {
425    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426        write!(f, "{}", self.0)
427    }
428}
429
/// A predicted set of edits for a buffer.
#[derive(Clone)]
pub struct EditPrediction {
    /// Identifier of this prediction.
    id: EditPredictionId,
    /// Predicted edits as anchor ranges paired with their replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the `edits` are expressed against; used as the "old" snapshot
    /// when interpolating onto newer buffer contents.
    snapshot: BufferSnapshot,
    /// Preview computed by `Buffer::preview_edits` for `edits`.
    edit_preview: EditPreview,
}
437
438impl EditPrediction {
439    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
440        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
441    }
442}
443
/// An in-flight prediction request tracked by `ZetaEditPredictionProvider`.
struct PendingPrediction {
    /// Id assigned from `next_pending_prediction_id` when the request was started.
    id: usize,
    /// Task driving the request; stored so it is owned by the pending entry.
    _task: Task<()>,
}
448
449impl EditPredictionProvider for ZetaEditPredictionProvider {
    /// Returns the provider's name string, `"zed-predict2"`.
    fn name() -> &'static str {
        // TODO [zeta2]
        "zed-predict2"
    }
454
    /// Returns the human-readable provider name.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions 2"
    }
458
    /// Always returns `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
462
    /// Always returns `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
466
    /// Data collection is not implemented yet, so this always reports
    /// `DataCollectionState::Unsupported`.
    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
        // TODO [zeta2]
        DataCollectionState::Unsupported
    }
471
    /// Currently a no-op; data collection is not implemented yet.
    fn toggle_data_collection(&mut self, _cx: &mut App) {
        // TODO [zeta2]
    }
475
476    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
477        self.zeta.read(cx).usage(cx)
478    }
479
    /// Always returns `true`; the buffer and cursor position are not consulted.
    fn is_enabled(
        &self,
        _buffer: &Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
488
    /// Returns `true` while at least one prediction request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_predictions.is_empty()
    }
492
493    fn refresh(
494        &mut self,
495        project: Option<Entity<project::Project>>,
496        buffer: Entity<language::Buffer>,
497        cursor_position: language::Anchor,
498        _debounce: bool,
499        cx: &mut Context<Self>,
500    ) {
501        let Some(project) = project else {
502            return;
503        };
504
505        if self
506            .zeta
507            .read(cx)
508            .user_store
509            .read_with(cx, |user_store, _cx| {
510                user_store.account_too_young() || user_store.has_overdue_invoices()
511            })
512        {
513            return;
514        }
515
516        if let Some(current_prediction) = self.current_prediction.as_ref() {
517            let snapshot = buffer.read(cx).snapshot();
518            if current_prediction
519                .prediction
520                .interpolate(&snapshot)
521                .is_some()
522            {
523                return;
524            }
525        }
526
527        let pending_prediction_id = self.next_pending_prediction_id;
528        self.next_pending_prediction_id += 1;
529        let last_request_timestamp = self.last_request_timestamp;
530
531        let task = cx.spawn(async move |this, cx| {
532            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
533                .checked_duration_since(Instant::now())
534            {
535                cx.background_executor().timer(timeout).await;
536            }
537
538            let prediction_request = this.update(cx, |this, cx| {
539                this.last_request_timestamp = Instant::now();
540                this.zeta.update(cx, |zeta, cx| {
541                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
542                })
543            });
544
545            let prediction = match prediction_request {
546                Ok(prediction_request) => {
547                    let prediction_request = prediction_request.await;
548                    prediction_request.map(|c| {
549                        c.map(|prediction| CurrentEditPrediction {
550                            buffer_id: buffer.entity_id(),
551                            prediction,
552                        })
553                    })
554                }
555                Err(error) => Err(error),
556            };
557
558            this.update(cx, |this, cx| {
559                if this.pending_predictions[0].id == pending_prediction_id {
560                    this.pending_predictions.remove(0);
561                } else {
562                    this.pending_predictions.clear();
563                }
564
565                let Some(new_prediction) = prediction
566                    .context("edit prediction failed")
567                    .log_err()
568                    .flatten()
569                else {
570                    cx.notify();
571                    return;
572                };
573
574                if let Some(old_prediction) = this.current_prediction.as_ref() {
575                    let snapshot = buffer.read(cx).snapshot();
576                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
577                        this.current_prediction = Some(new_prediction);
578                    }
579                } else {
580                    this.current_prediction = Some(new_prediction);
581                }
582
583                cx.notify();
584            })
585            .ok();
586        });
587
588        // We always maintain at most two pending predictions. When we already
589        // have two, we replace the newest one.
590        if self.pending_predictions.len() <= 1 {
591            self.pending_predictions.push(PendingPrediction {
592                id: pending_prediction_id,
593                _task: task,
594            });
595        } else if self.pending_predictions.len() == 2 {
596            self.pending_predictions.pop();
597            self.pending_predictions.push(PendingPrediction {
598                id: pending_prediction_id,
599                _task: task,
600            });
601        }
602
603        cx.notify();
604    }
605
    /// No-op: this provider does not act on cycle requests; all arguments are ignored.
    fn cycle(
        &mut self,
        _buffer: Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _direction: Direction,
        _cx: &mut Context<Self>,
    ) {
    }
614
615    fn accept(&mut self, _cx: &mut Context<Self>) {
616        // TODO [zeta2] report accept
617        self.current_prediction.take();
618        self.pending_predictions.clear();
619    }
620
621    fn discard(&mut self, _cx: &mut Context<Self>) {
622        self.pending_predictions.clear();
623        self.current_prediction.take();
624    }
625
    /// Returns the part of the current prediction that should be offered at
    /// `cursor_position` in `buffer`.
    ///
    /// The current prediction is discarded (and `None` returned) when it belongs to
    /// a different buffer or can no longer be interpolated onto the buffer's current
    /// snapshot. Otherwise the edit whose start or end row is closest to the cursor
    /// row is selected, and neighbouring edits are included while each one lies
    /// within one row of that closest edit.
    fn suggest(
        &mut self,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Option<edit_prediction::EditPrediction> {
        let CurrentEditPrediction {
            buffer_id,
            prediction,
            ..
        } = self.current_prediction.as_mut()?;

        // Invalidate previous prediction if it was generated for a different buffer.
        if *buffer_id != buffer.entity_id() {
            self.current_prediction.take();
            return None;
        }

        let buffer = buffer.read(cx);
        // The prediction no longer applies to the current buffer contents: drop it.
        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
            self.current_prediction.take();
            return None;
        };

        let cursor_row = cursor_position.to_point(buffer).row;
        // Pick the edit whose nearest boundary row is closest to the cursor row.
        let (closest_edit_ix, (closest_edit_range, _)) =
            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
                cmp::min(distance_from_start, distance_from_end)
            })?;

        // Walk backwards from the closest edit, including preceding edits that end
        // within one row of the closest edit's start. Note the distance is always
        // measured against the closest edit, not against the previously included one.
        // NOTE(review): the subtraction assumes each earlier edit ends at or above the
        // closest edit's start row; if `row` is unsigned a violation would underflow —
        // confirm edits are ordered and non-overlapping.
        let mut edit_start_ix = closest_edit_ix;
        for (range, _) in edits[..edit_start_ix].iter().rev() {
            let distance_from_closest_edit =
                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_start_ix -= 1;
            } else {
                break;
            }
        }

        // Same walk forwards: include following edits that start within one row of
        // the closest edit's end (same ordering assumption as above).
        let mut edit_end_ix = closest_edit_ix + 1;
        for (range, _) in &edits[edit_end_ix..] {
            let distance_from_closest_edit =
                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
            if distance_from_closest_edit <= 1 {
                edit_end_ix += 1;
            } else {
                break;
            }
        }

        Some(edit_prediction::EditPrediction {
            id: Some(prediction.id.to_string().into()),
            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
            edit_preview: Some(prediction.edit_preview.clone()),
        })
    }
686}
687
688fn make_cloud_request(
689    excerpt_path: PathBuf,
690    context: EditPredictionContext,
691    events: Vec<predict_edits_v3::Event>,
692    can_collect_data: bool,
693    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
694    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
695    debug_info: bool,
696    worktrees: &Vec<worktree::Snapshot>,
697    index_state: Option<&SyntaxIndexState>,
698) -> predict_edits_v3::PredictEditsRequest {
699    let mut signatures = Vec::new();
700    let mut declaration_to_signature_index = HashMap::default();
701    let mut referenced_declarations = Vec::new();
702
703    for snippet in context.snippets {
704        let project_entry_id = snippet.declaration.project_entry_id();
705        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
706        // Note that currently full_path is currently being used for excerpt_path.
707        let Some(path) = worktrees.iter().find_map(|worktree| {
708            let abs_path = worktree.abs_path();
709            worktree
710                .entry_for_id(project_entry_id)
711                .map(|e| abs_path.join(&e.path))
712        }) else {
713            continue;
714        };
715
716        let parent_index = index_state.and_then(|index_state| {
717            snippet.declaration.parent().and_then(|parent| {
718                add_signature(
719                    parent,
720                    &mut declaration_to_signature_index,
721                    &mut signatures,
722                    index_state,
723                )
724            })
725        });
726
727        let (text, text_is_truncated) = snippet.declaration.item_text();
728        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
729            path,
730            text: text.into(),
731            range: snippet.declaration.item_range(),
732            text_is_truncated,
733            signature_range: snippet.declaration.signature_range_in_item_text(),
734            parent_index,
735            score_components: snippet.score_components,
736            signature_score: snippet.scores.signature,
737            declaration_score: snippet.scores.declaration,
738        });
739    }
740
741    let excerpt_parent = index_state.and_then(|index_state| {
742        context
743            .excerpt
744            .parent_declarations
745            .last()
746            .and_then(|(parent, _)| {
747                add_signature(
748                    *parent,
749                    &mut declaration_to_signature_index,
750                    &mut signatures,
751                    index_state,
752                )
753            })
754    });
755
756    predict_edits_v3::PredictEditsRequest {
757        excerpt_path,
758        excerpt: context.excerpt_text.body,
759        excerpt_range: context.excerpt.range,
760        cursor_offset: context.cursor_offset_in_excerpt,
761        referenced_declarations,
762        signatures,
763        excerpt_parent,
764        // todo!
765        events,
766        can_collect_data,
767        diagnostic_groups,
768        git_info,
769        debug_info,
770    }
771}
772
773fn add_signature(
774    declaration_id: DeclarationId,
775    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
776    signatures: &mut Vec<Signature>,
777    index: &SyntaxIndexState,
778) -> Option<usize> {
779    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
780        return Some(*signature_index);
781    }
782    let Some(parent_declaration) = index.declaration(declaration_id) else {
783        log::error!("bug: missing parent declaration");
784        return None;
785    };
786    let parent_index = parent_declaration.parent().and_then(|parent| {
787        add_signature(parent, declaration_to_signature_index, signatures, index)
788    });
789    let (text, text_is_truncated) = parent_declaration.signature_text();
790    let signature_index = signatures.len();
791    signatures.push(Signature {
792        text: text.into(),
793        text_is_truncated,
794        parent_index,
795    });
796    declaration_to_signature_index.insert(declaration_id, signature_index);
797    Some(signature_index)
798}
799
/// Re-bases predicted edits made against `old_snapshot` onto `new_snapshot`.
///
/// The user's edits between the two snapshots are walked in order alongside
/// `current_edits`:
/// - predicted edits ending strictly before a user edit are kept unchanged;
/// - a user edit that replaces exactly the range of the next predicted edit with a
///   prefix of the predicted text consumes that prediction, leaving only the
///   not-yet-typed suffix (if any) as an insertion;
/// - any other user edit invalidates the prediction.
///
/// Returns `None` when the prediction is invalidated or when no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over every predicted edit that ends strictly before this user edit.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            // Only a user edit over exactly the predicted range can be reconciled.
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // The user typed a prefix of the prediction: keep what is left of it.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit conflicts with (or is unrelated to) the prediction.
        return None;
    }

    // Predicted edits located after the last user edit are kept as they are.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
845
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Walks a prediction through a series of user edits to the buffer and
    /// checks, after each one, which edits `interpolate` still reports
    /// (expressed as offsets into the current buffer contents).
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));

        // Predicted edits: replace "rem" with "REM" and delete the "um" of "ipsum".
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            let anchored = to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            );
            anchored.into()
        });

        let preview_task = cx.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx));
        let edit_preview = preview_task.await;
        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot,
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction is reported as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the text the prediction wanted to replace leaves a pure insertion.
            apply_edit(&buffer, 2..5, "", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing that deletion restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing the predicted text one character at a time shrinks the
            // remaining insertion until only the deletion is left.
            apply_edit(&buffer, 2..5, "R", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            apply_edit(&buffer, 3..3, "E", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            apply_edit(&buffer, 4..4, "M", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".to_string())]
            );

            // Removing the "M" that was just typed brings its insertion back.
            apply_edit(&buffer, 4..5, "", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion by hand leaves only the insertion.
            apply_edit(&buffer, 8..10, "", cx);
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string())]
            );

            // After this deletion the test expects no interpolated prediction at all.
            apply_edit(&buffer, 4..6, "", cx);
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Replaces `range` in `buffer` with `text`, passing `None` as the second
    /// argument of `Buffer::edit`.
    fn apply_edit(buffer: &Entity<Buffer>, range: Range<usize>, text: &str, cx: &mut App) {
        buffer.update(cx, |buffer, cx| buffer.edit([(range, text)], None, cx));
    }

    /// Interpolates `prediction` against the buffer's current snapshot and
    /// returns the resulting edits as offset ranges.
    ///
    /// Panics if `interpolate` returns `None`.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let snapshot = buffer.read(cx).snapshot();
        let interpolated = prediction.interpolate(&snapshot).unwrap();
        from_prediction_edits(&interpolated, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`.
    ///
    /// Each range starts at `anchor_after(start)` and ends at
    /// `anchor_before(end)`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, text) in iterator {
            let start = buffer.anchor_after(offsets.start);
            let end = buffer.anchor_before(offsets.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset ranges resolved against
    /// the current contents of `buffer`, cloning each edit's text.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (anchors, text) in editor_edits {
            let start = anchors.start.to_offset(buffer);
            let end = anchors.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}