zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, Signature};
   6use cloud_llm_client::{
   7    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   8};
   9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  10use edit_prediction_context::{
  11    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  12    SyntaxIndexState,
  13};
  14use futures::AsyncReadExt as _;
  15use futures::channel::mpsc;
  16use gpui::http_client::Method;
  17use gpui::{
  18    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
  19    prelude::*,
  20};
  21use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
  22use language::{BufferSnapshot, EditPreview};
  23use language_model::{LlmApiToken, RefreshLlmTokenListener};
  24use project::Project;
  25use release_channel::AppVersion;
  26use std::cmp;
  27use std::collections::{HashMap, VecDeque, hash_map};
  28use std::path::PathBuf;
  29use std::str::FromStr as _;
  30use std::time::{Duration, Instant};
  31use std::{ops::Range, sync::Arc};
  32use thiserror::Error;
  33use util::{ResultExt as _, some_or_debug_panic};
  34use uuid::Uuid;
  35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  36
/// Maximum gap between two consecutive changes for `Zeta::push_event` to merge
/// them into a single `Event::BufferChange` (the changes must also belong to
/// the same buffer and be contiguous in version).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  38
/// Maximum number of events to track per project.
///
/// When a project's event queue reaches this length, `Zeta::push_event` drops
/// the oldest half before appending the new event.
const MAX_EVENT_COUNT: usize = 16;
  41
/// Newtype that stores the shared [`Zeta`] entity as a gpui global.
///
/// Looked up and installed by [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  46
/// Edit-prediction state: server credentials, per-project tracking data, and
/// request configuration. Obtained through [`Zeta::global`].
pub struct Zeta {
    /// Client used to obtain the HTTP client and LLM tokens for requests.
    client: Arc<Client>,
    /// User store; read for usage data and updated with the usage parsed from
    /// prediction responses.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential with prediction requests.
    llm_token: LlmApiToken,
    /// Keeps alive the subscription that refreshes `llm_token`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options passed to `EditPredictionContext::gather_context`.
    excerpt_options: EditPredictionExcerptOptions,
    /// Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When set, debug info (or an error string) for each request that
    /// gathered context is sent on this channel. Installed by `debug_info`.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
  57
/// Debug payload emitted on the channel returned by `Zeta::debug_info` after
/// a prediction request completes.
pub struct PredictionDebugInfo {
    /// Context that was gathered for the request.
    pub context: EditPredictionContext,
    /// Time elapsed between starting context gathering and finishing
    /// construction of the cloud request.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the response.
    pub request: RequestDebugInfo,
}
  63
/// Alias for the debug info type carried in `predict_edits_v3` responses.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  65
/// State tracked for a single registered project.
struct ZetaProject {
    /// Syntax index created for the project; its state is handed to context
    /// gathering and request construction.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers being observed for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
  71
/// A buffer whose edits are being recorded as events.
struct RegisteredBuffer {
    /// Snapshot recorded at registration or at the last reported change; used
    /// as the "old" side of the next `Event::BufferChange`.
    snapshot: BufferSnapshot,
    /// The buffer-event subscription and the release observer, in that order.
    _subscriptions: [gpui::Subscription; 2],
}
  76
/// An event recorded in a project's event queue.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change (or before the first of several
        /// coalesced changes).
        old_snapshot: BufferSnapshot,
        /// Buffer state after the most recent change merged into this event.
        new_snapshot: BufferSnapshot,
        /// When the most recent change merged into this event was recorded.
        timestamp: Instant,
    },
}
  85
// No methods yet: the prompt rendering below is kept commented out until
// events are included in requests.
impl Event {
    // TODO: Actually use these events in the prompt
    // fn to_prompt(&self) -> String {
    //     match self {
    //         Event::BufferChange {
    //             old_snapshot,
    //             new_snapshot,
    //             ..
    //         } => {
    //             let mut prompt = String::new();

    //             let old_path = old_snapshot
    //                 .file()
    //                 .map(|f| f.path().as_ref())
    //                 .unwrap_or(Path::new("untitled"));
    //             let new_path = new_snapshot
    //                 .file()
    //                 .map(|f| f.path().as_ref())
    //                 .unwrap_or(Path::new("untitled"));
    //             if old_path != new_path {
    //                 writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
    //             }

    //             let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
    //             if !diff.is_empty() {
    //                 write!(
    //                     prompt,
    //                     "User edited {:?}:\n```diff\n{}\n```",
    //                     new_path, diff
    //                 )
    //                 .unwrap();
    //             }

    //             prompt
    //         }
    //     }
    // }
}
 124
 125impl Zeta {
 126    pub fn global(
 127        client: &Arc<Client>,
 128        user_store: &Entity<UserStore>,
 129        cx: &mut App,
 130    ) -> Entity<Self> {
 131        cx.try_global::<ZetaGlobal>()
 132            .map(|global| global.0.clone())
 133            .unwrap_or_else(|| {
 134                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 135                cx.set_global(ZetaGlobal(zeta.clone()));
 136                zeta
 137            })
 138    }
 139
 140    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 141        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 142
 143        Self {
 144            projects: HashMap::new(),
 145            client,
 146            user_store,
 147            excerpt_options: EditPredictionExcerptOptions {
 148                max_bytes: 512,
 149                min_bytes: 128,
 150                target_before_cursor_over_total_bytes: 0.5,
 151            },
 152            llm_token: LlmApiToken::default(),
 153            _llm_token_subscription: cx.subscribe(
 154                &refresh_llm_token_listener,
 155                |this, _listener, _event, cx| {
 156                    let client = this.client.clone();
 157                    let llm_token = this.llm_token.clone();
 158                    cx.spawn(async move |_this, _cx| {
 159                        llm_token.refresh(&client).await?;
 160                        anyhow::Ok(())
 161                    })
 162                    .detach_and_log_err(cx);
 163                },
 164            ),
 165            update_required: false,
 166            debug_tx: None,
 167        }
 168    }
 169
 170    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 171        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 172        self.debug_tx = Some(debug_watch_tx);
 173        debug_watch_rx
 174    }
 175
 176    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 177        self.user_store.read(cx).edit_prediction_usage()
 178    }
 179
 180    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 181        self.get_or_init_zeta_project(project, cx);
 182    }
 183
 184    pub fn register_buffer(
 185        &mut self,
 186        buffer: &Entity<Buffer>,
 187        project: &Entity<Project>,
 188        cx: &mut Context<Self>,
 189    ) {
 190        let zeta_project = self.get_or_init_zeta_project(project, cx);
 191        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 192    }
 193
 194    fn get_or_init_zeta_project(
 195        &mut self,
 196        project: &Entity<Project>,
 197        cx: &mut App,
 198    ) -> &mut ZetaProject {
 199        self.projects
 200            .entry(project.entity_id())
 201            .or_insert_with(|| ZetaProject {
 202                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 203                events: VecDeque::new(),
 204                registered_buffers: HashMap::new(),
 205            })
 206    }
 207
 208    fn register_buffer_impl<'a>(
 209        zeta_project: &'a mut ZetaProject,
 210        buffer: &Entity<Buffer>,
 211        project: &Entity<Project>,
 212        cx: &mut Context<Self>,
 213    ) -> &'a mut RegisteredBuffer {
 214        let buffer_id = buffer.entity_id();
 215        match zeta_project.registered_buffers.entry(buffer_id) {
 216            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 217            hash_map::Entry::Vacant(entry) => {
 218                let snapshot = buffer.read(cx).snapshot();
 219                let project_entity_id = project.entity_id();
 220                entry.insert(RegisteredBuffer {
 221                    snapshot,
 222                    _subscriptions: [
 223                        cx.subscribe(buffer, {
 224                            let project = project.downgrade();
 225                            move |this, buffer, event, cx| {
 226                                if let language::BufferEvent::Edited = event
 227                                    && let Some(project) = project.upgrade()
 228                                {
 229                                    this.report_changes_for_buffer(&buffer, &project, cx);
 230                                }
 231                            }
 232                        }),
 233                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 234                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 235                            else {
 236                                return;
 237                            };
 238                            zeta_project.registered_buffers.remove(&buffer_id);
 239                        }),
 240                    ],
 241                })
 242            }
 243        }
 244    }
 245
 246    fn report_changes_for_buffer(
 247        &mut self,
 248        buffer: &Entity<Buffer>,
 249        project: &Entity<Project>,
 250        cx: &mut Context<Self>,
 251    ) -> BufferSnapshot {
 252        let zeta_project = self.get_or_init_zeta_project(project, cx);
 253        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 254
 255        let new_snapshot = buffer.read(cx).snapshot();
 256        if new_snapshot.version != registered_buffer.snapshot.version {
 257            let old_snapshot =
 258                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 259            Self::push_event(
 260                zeta_project,
 261                Event::BufferChange {
 262                    old_snapshot,
 263                    new_snapshot: new_snapshot.clone(),
 264                    timestamp: Instant::now(),
 265                },
 266            );
 267        }
 268
 269        new_snapshot
 270    }
 271
 272    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 273        let events = &mut zeta_project.events;
 274
 275        if let Some(Event::BufferChange {
 276            new_snapshot: last_new_snapshot,
 277            timestamp: last_timestamp,
 278            ..
 279        }) = events.back_mut()
 280        {
 281            // Coalesce edits for the same buffer when they happen one after the other.
 282            let Event::BufferChange {
 283                old_snapshot,
 284                new_snapshot,
 285                timestamp,
 286            } = &event;
 287
 288            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 289                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 290                && old_snapshot.version == last_new_snapshot.version
 291            {
 292                *last_new_snapshot = new_snapshot.clone();
 293                *last_timestamp = *timestamp;
 294                return;
 295            }
 296        }
 297
 298        if events.len() >= MAX_EVENT_COUNT {
 299            // These are halved instead of popping to improve prompt caching.
 300            events.drain(..MAX_EVENT_COUNT / 2);
 301        }
 302
 303        events.push_back(event);
 304    }
 305
    /// Requests an edit prediction for `buffer` at `position`.
    ///
    /// Context is gathered and the cloud request is sent on a background task;
    /// the response's edits are then converted to anchor ranges and
    /// interpolated onto the buffer's then-current snapshot.
    ///
    /// The returned task resolves to:
    /// - `Ok(Some(prediction))` on success,
    /// - `Ok(None)` when no context could be gathered or the edits could not be
    ///   interpolated onto the current buffer contents,
    /// - `Err(_)` when the buffer has no file, the request fails, or the buffer
    ///   entity can no longer be read.
    ///
    /// If the failure is a `ZedUpdateRequiredError`, `update_required` is set
    /// and a notification with an "Update Zed" link is shown before the error
    /// is returned. When a debug channel is installed, the outcome of each
    /// request that gathered context is also sent on it.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        // The project is only looked up, not registered; without project state
        // no syntax index state is passed along.
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let excerpt_options = self.excerpt_options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            async move {
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_point = position.to_point(&snapshot);

                let before_retrieval = chrono::Utc::now();

                let Some(context) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &excerpt_options,
                    index_state.as_deref(),
                ) else {
                    return Ok(None);
                };

                // Keep a copy of the context only when someone is listening for
                // debug info, since `context` itself is moved into the request.
                let debug_context = if let Some(debug_tx) = debug_tx {
                    Some((debug_tx, context.clone()))
                } else {
                    None
                };

                let request = make_cloud_request(
                    excerpt_path.clone(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    None,
                    debug_context.is_some(),
                    &worktree_snapshots,
                    index_state.as_deref(),
                );

                // Measured after the request is built, so it covers context
                // gathering and request construction but not the network call.
                let retrieval_time = chrono::Utc::now() - before_retrieval;
                let response = Self::perform_request(client, llm_token, app_version, request).await;

                if let Some((debug_tx, context)) = debug_context {
                    // Send failures (e.g. no receiver) are ignored.
                    debug_tx
                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
                            |response| {
                                let Some(request) =
                                    some_or_debug_panic(response.0.debug_info.clone())
                                else {
                                    return Err("Missing debug info".to_string());
                                };
                                Ok(PredictionDebugInfo {
                                    context,
                                    request,
                                    retrieval_time,
                                })
                            },
                        ))
                        .ok();
                }

                anyhow::Ok(Some(response?))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    log::debug!("predicted edits: {:?}", &response.edits);

                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc

                    // TODO produce smaller edits by diffing against snapshot first
                    //
                    // Cloud returns entire snippets/excerpts ranges as they were included
                    // in the request, but we should display smaller edits to the user.
                    //
                    // We can do this by computing a diff of each one against the snapshot.
                    // Similar to zeta::Zeta::compute_edits, but per edit.
                    let edits = response
                        .edits
                        .into_iter()
                        .map(|edit| {
                            // TODO edits to different files
                            (
                                snapshot.anchor_before(edit.range.start)
                                    ..snapshot.anchor_before(edit.range.end),
                                edit.content,
                            )
                        })
                        .collect::<Vec<_>>()
                        .into();

                    // Re-anchor the edits onto the buffer as it is now; the
                    // prediction is dropped if that interpolation fails.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: EditPredictionId(response.request_id),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }
 491
 492    async fn perform_request(
 493        client: Arc<Client>,
 494        llm_token: LlmApiToken,
 495        app_version: SemanticVersion,
 496        request: predict_edits_v3::PredictEditsRequest,
 497    ) -> Result<(
 498        predict_edits_v3::PredictEditsResponse,
 499        Option<EditPredictionUsage>,
 500    )> {
 501        let http_client = client.http_client();
 502        let mut token = llm_token.acquire(&client).await?;
 503        let mut did_retry = false;
 504
 505        loop {
 506            let request_builder = http_client::Request::builder().method(Method::POST);
 507            let request_builder =
 508                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 509                    request_builder.uri(predict_edits_url)
 510                } else {
 511                    request_builder.uri(
 512                        http_client
 513                            .build_zed_llm_url("/predict_edits/v3", &[])?
 514                            .as_ref(),
 515                    )
 516                };
 517            let request = request_builder
 518                .header("Content-Type", "application/json")
 519                .header("Authorization", format!("Bearer {}", token))
 520                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 521                .body(serde_json::to_string(&request)?.into())?;
 522
 523            let mut response = http_client.send(request).await?;
 524
 525            if let Some(minimum_required_version) = response
 526                .headers()
 527                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 528                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 529            {
 530                anyhow::ensure!(
 531                    app_version >= minimum_required_version,
 532                    ZedUpdateRequiredError {
 533                        minimum_version: minimum_required_version
 534                    }
 535                );
 536            }
 537
 538            if response.status().is_success() {
 539                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 540
 541                let mut body = Vec::new();
 542                response.body_mut().read_to_end(&mut body).await?;
 543                return Ok((serde_json::from_slice(&body)?, usage));
 544            } else if !did_retry
 545                && response
 546                    .headers()
 547                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 548                    .is_some()
 549            {
 550                did_retry = true;
 551                token = llm_token.refresh(&client).await?;
 552            } else {
 553                let mut body = String::new();
 554                response.body_mut().read_to_string(&mut body).await?;
 555                anyhow::bail!(
 556                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 557                    response.status(),
 558                    body
 559                );
 560            }
 561        }
 562    }
 563}
 564
/// Error returned when the server's minimum required version is newer than
/// the running app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server in the response headers.
    minimum_version: SemanticVersion,
}
 572
/// `EditPredictionProvider` implementation backed by the shared [`Zeta`] entity.
pub struct ZetaEditPredictionProvider {
    /// Shared prediction state used to issue requests.
    zeta: Entity<Zeta>,
    /// The prediction currently on offer, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next pending request; incremented per refresh.
    next_pending_prediction_id: usize,
    /// In-flight requests; at most two are kept at a time.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last request was issued; used to throttle new requests.
    last_request_timestamp: Instant,
}
 580
 581impl ZetaEditPredictionProvider {
 582    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 583
 584    pub fn new(
 585        project: Option<&Entity<Project>>,
 586        client: &Arc<Client>,
 587        user_store: &Entity<UserStore>,
 588        cx: &mut App,
 589    ) -> Self {
 590        let zeta = Zeta::global(client, user_store, cx);
 591        if let Some(project) = project {
 592            zeta.update(cx, |zeta, cx| {
 593                zeta.register_project(project, cx);
 594            });
 595        }
 596
 597        Self {
 598            zeta,
 599            current_prediction: None,
 600            next_pending_prediction_id: 0,
 601            pending_predictions: ArrayVec::new(),
 602            last_request_timestamp: Instant::now(),
 603        }
 604    }
 605}
 606
/// A prediction together with the buffer it was produced for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    /// The prediction itself.
    prediction: EditPrediction,
}
 612
 613impl CurrentEditPrediction {
 614    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 615        if self.buffer_id != old_prediction.buffer_id {
 616            return true;
 617        }
 618
 619        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 620            return true;
 621        };
 622        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 623            return false;
 624        };
 625
 626        if old_edits.len() == 1 && new_edits.len() == 1 {
 627            let (old_range, old_text) = &old_edits[0];
 628            let (new_range, new_text) = &new_edits[0];
 629            new_range == old_range && new_text.starts_with(old_text)
 630        } else {
 631            true
 632        }
 633    }
 634}
 635
/// Identifier of an edit prediction, wrapping the request's UUID.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
 638
 639impl From<EditPredictionId> for gpui::ElementId {
 640    fn from(value: EditPredictionId) -> Self {
 641        gpui::ElementId::Uuid(value.0)
 642    }
 643}
 644
 645impl std::fmt::Display for EditPredictionId {
 646    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 647        write!(f, "{}", self.0)
 648    }
 649}
 650
/// A predicted set of edits for a buffer.
#[derive(Clone)]
pub struct EditPrediction {
    /// Id taken from the prediction response's request id.
    id: EditPredictionId,
    /// Anchor ranges paired with their replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto when the prediction
    /// was created; the base for later interpolation.
    snapshot: BufferSnapshot,
    /// Preview computed by the buffer for these edits.
    edit_preview: EditPreview,
}
 658
 659impl EditPrediction {
 660    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 661        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 662    }
 663}
 664
/// An in-flight prediction request tracked by the provider.
struct PendingPrediction {
    /// Sequential id assigned in `refresh`; compared against the oldest
    /// pending entry when the request completes.
    id: usize,
    /// Handle that keeps the request task owned by this entry.
    _task: Task<()>,
}
 669
 670impl EditPredictionProvider for ZetaEditPredictionProvider {
    /// Identifier string for this provider.
    fn name() -> &'static str {
        "zed-predict2"
    }
 674
    /// Human-readable name for this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions 2"
    }
 678
    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
 682
    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
 686
    /// Data collection is not implemented for this provider yet, so this
    /// always reports `Unsupported`.
    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
        // TODO [zeta2]
        DataCollectionState::Unsupported
    }
 691
    /// Currently a no-op; data collection is not implemented for this provider.
    fn toggle_data_collection(&mut self, _cx: &mut App) {
        // TODO [zeta2]
    }
 695
 696    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 697        self.zeta.read(cx).usage(cx)
 698    }
 699
    /// This provider reports itself as enabled for every buffer and cursor
    /// position; all arguments are ignored.
    fn is_enabled(
        &self,
        _buffer: &Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
 708
    /// Returns `true` while at least one prediction request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_predictions.is_empty()
    }
 712
    /// Starts a new prediction request for `buffer` at `cursor_position`.
    ///
    /// Does nothing when there is no project, when the user's account is too
    /// young or has overdue invoices, or when the current prediction still
    /// interpolates onto the buffer's present contents. Otherwise a task is
    /// spawned that waits out the remainder of `THROTTLE_TIMEOUT` since the
    /// last request, asks `Zeta` for a prediction, and stores the result as the
    /// current prediction if it should replace the existing one. Request
    /// failures are logged. `_debounce` is ignored.
    //
    // NOTE(review): `Zeta::update_required` is set by `request_prediction` but
    // is not consulted here — confirm whether refresh should bail out once an
    // update is required.
    fn refresh(
        &mut self,
        project: Option<Entity<project::Project>>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        _debounce: bool,
        cx: &mut Context<Self>,
    ) {
        let Some(project) = project else {
            return;
        };

        if self
            .zeta
            .read(cx)
            .user_store
            .read_with(cx, |user_store, _cx| {
                user_store.account_too_young() || user_store.has_overdue_invoices()
            })
        {
            return;
        }

        // Keep the existing prediction while it still applies to the buffer.
        if let Some(current_prediction) = self.current_prediction.as_ref() {
            let snapshot = buffer.read(cx).snapshot();
            if current_prediction
                .prediction
                .interpolate(&snapshot)
                .is_some()
            {
                return;
            }
        }

        let pending_prediction_id = self.next_pending_prediction_id;
        self.next_pending_prediction_id += 1;
        let last_request_timestamp = self.last_request_timestamp;

        let task = cx.spawn(async move |this, cx| {
            // Throttle: sleep only if the timeout window has not elapsed yet.
            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
                .checked_duration_since(Instant::now())
            {
                cx.background_executor().timer(timeout).await;
            }

            let prediction_request = this.update(cx, |this, cx| {
                this.last_request_timestamp = Instant::now();
                this.zeta.update(cx, |zeta, cx| {
                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
                })
            });

            let prediction = match prediction_request {
                Ok(prediction_request) => {
                    let prediction_request = prediction_request.await;
                    prediction_request.map(|c| {
                        c.map(|prediction| CurrentEditPrediction {
                            buffer_id: buffer.entity_id(),
                            prediction,
                        })
                    })
                }
                Err(error) => Err(error),
            };

            this.update(cx, |this, cx| {
                // If this is the oldest pending request, remove just its entry;
                // otherwise clear every pending entry.
                if this.pending_predictions[0].id == pending_prediction_id {
                    this.pending_predictions.remove(0);
                } else {
                    this.pending_predictions.clear();
                }

                let Some(new_prediction) = prediction
                    .context("edit prediction failed")
                    .log_err()
                    .flatten()
                else {
                    cx.notify();
                    return;
                };

                if let Some(old_prediction) = this.current_prediction.as_ref() {
                    let snapshot = buffer.read(cx).snapshot();
                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
                        this.current_prediction = Some(new_prediction);
                    }
                } else {
                    this.current_prediction = Some(new_prediction);
                }

                cx.notify();
            })
            .ok();
        });

        // We always maintain at most two pending predictions. When we already
        // have two, we replace the newest one.
        if self.pending_predictions.len() <= 1 {
            self.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        } else if self.pending_predictions.len() == 2 {
            self.pending_predictions.pop();
            self.pending_predictions.push(PendingPrediction {
                id: pending_prediction_id,
                _task: task,
            });
        }

        cx.notify();
    }
 825
    /// No-op: this provider does not implement cycling between predictions;
    /// all arguments are ignored.
    fn cycle(
        &mut self,
        _buffer: Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _direction: Direction,
        _cx: &mut Context<Self>,
    ) {
    }
 834
 835    fn accept(&mut self, _cx: &mut Context<Self>) {
 836        // TODO [zeta2] report accept
 837        self.current_prediction.take();
 838        self.pending_predictions.clear();
 839    }
 840
 841    fn discard(&mut self, _cx: &mut Context<Self>) {
 842        self.pending_predictions.clear();
 843        self.current_prediction.take();
 844    }
 845
 846    fn suggest(
 847        &mut self,
 848        buffer: &Entity<language::Buffer>,
 849        cursor_position: language::Anchor,
 850        cx: &mut Context<Self>,
 851    ) -> Option<edit_prediction::EditPrediction> {
 852        let CurrentEditPrediction {
 853            buffer_id,
 854            prediction,
 855            ..
 856        } = self.current_prediction.as_mut()?;
 857
 858        // Invalidate previous prediction if it was generated for a different buffer.
 859        if *buffer_id != buffer.entity_id() {
 860            self.current_prediction.take();
 861            return None;
 862        }
 863
 864        let buffer = buffer.read(cx);
 865        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
 866            self.current_prediction.take();
 867            return None;
 868        };
 869
 870        let cursor_row = cursor_position.to_point(buffer).row;
 871        let (closest_edit_ix, (closest_edit_range, _)) =
 872            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
 873                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
 874                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
 875                cmp::min(distance_from_start, distance_from_end)
 876            })?;
 877
 878        let mut edit_start_ix = closest_edit_ix;
 879        for (range, _) in edits[..edit_start_ix].iter().rev() {
 880            let distance_from_closest_edit =
 881                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
 882            if distance_from_closest_edit <= 1 {
 883                edit_start_ix -= 1;
 884            } else {
 885                break;
 886            }
 887        }
 888
 889        let mut edit_end_ix = closest_edit_ix + 1;
 890        for (range, _) in &edits[edit_end_ix..] {
 891            let distance_from_closest_edit =
 892                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
 893            if distance_from_closest_edit <= 1 {
 894                edit_end_ix += 1;
 895            } else {
 896                break;
 897            }
 898        }
 899
 900        Some(edit_prediction::EditPrediction {
 901            id: Some(prediction.id.to_string().into()),
 902            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
 903            edit_preview: Some(prediction.edit_preview.clone()),
 904        })
 905    }
 906}
 907
 908fn make_cloud_request(
 909    excerpt_path: PathBuf,
 910    context: EditPredictionContext,
 911    events: Vec<predict_edits_v3::Event>,
 912    can_collect_data: bool,
 913    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 914    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 915    debug_info: bool,
 916    worktrees: &Vec<worktree::Snapshot>,
 917    index_state: Option<&SyntaxIndexState>,
 918) -> predict_edits_v3::PredictEditsRequest {
 919    let mut signatures = Vec::new();
 920    let mut declaration_to_signature_index = HashMap::default();
 921    let mut referenced_declarations = Vec::new();
 922
 923    for snippet in context.snippets {
 924        let project_entry_id = snippet.declaration.project_entry_id();
 925        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
 926        // Note that currently full_path is currently being used for excerpt_path.
 927        let Some(path) = worktrees.iter().find_map(|worktree| {
 928            let abs_path = worktree.abs_path();
 929            worktree
 930                .entry_for_id(project_entry_id)
 931                .map(|e| abs_path.join(&e.path))
 932        }) else {
 933            continue;
 934        };
 935
 936        let parent_index = index_state.and_then(|index_state| {
 937            snippet.declaration.parent().and_then(|parent| {
 938                add_signature(
 939                    parent,
 940                    &mut declaration_to_signature_index,
 941                    &mut signatures,
 942                    index_state,
 943                )
 944            })
 945        });
 946
 947        let (text, text_is_truncated) = snippet.declaration.item_text();
 948        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 949            path,
 950            text: text.into(),
 951            range: snippet.declaration.item_range(),
 952            text_is_truncated,
 953            signature_range: snippet.declaration.signature_range_in_item_text(),
 954            parent_index,
 955            score_components: snippet.score_components,
 956            signature_score: snippet.scores.signature,
 957            declaration_score: snippet.scores.declaration,
 958        });
 959    }
 960
 961    let excerpt_parent = index_state.and_then(|index_state| {
 962        context
 963            .excerpt
 964            .parent_declarations
 965            .last()
 966            .and_then(|(parent, _)| {
 967                add_signature(
 968                    *parent,
 969                    &mut declaration_to_signature_index,
 970                    &mut signatures,
 971                    index_state,
 972                )
 973            })
 974    });
 975
 976    predict_edits_v3::PredictEditsRequest {
 977        excerpt_path,
 978        excerpt: context.excerpt_text.body,
 979        excerpt_range: context.excerpt.range,
 980        cursor_offset: context.cursor_offset_in_excerpt,
 981        referenced_declarations,
 982        signatures,
 983        excerpt_parent,
 984        // todo!
 985        events,
 986        can_collect_data,
 987        diagnostic_groups,
 988        git_info,
 989        debug_info,
 990    }
 991}
 992
 993fn add_signature(
 994    declaration_id: DeclarationId,
 995    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 996    signatures: &mut Vec<Signature>,
 997    index: &SyntaxIndexState,
 998) -> Option<usize> {
 999    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1000        return Some(*signature_index);
1001    }
1002    let Some(parent_declaration) = index.declaration(declaration_id) else {
1003        log::error!("bug: missing parent declaration");
1004        return None;
1005    };
1006    let parent_index = parent_declaration.parent().and_then(|parent| {
1007        add_signature(parent, declaration_to_signature_index, signatures, index)
1008    });
1009    let (text, text_is_truncated) = parent_declaration.signature_text();
1010    let signature_index = signatures.len();
1011    signatures.push(Signature {
1012        text: text.into(),
1013        text_is_truncated,
1014        parent_index,
1015    });
1016    declaration_to_signature_index.insert(declaration_id, signature_index);
1017    Some(signature_index)
1018}
1019
1020fn interpolate(
1021    old_snapshot: &BufferSnapshot,
1022    new_snapshot: &BufferSnapshot,
1023    current_edits: Arc<[(Range<Anchor>, String)]>,
1024) -> Option<Vec<(Range<Anchor>, String)>> {
1025    let mut edits = Vec::new();
1026
1027    let mut model_edits = current_edits.iter().peekable();
1028    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1029        while let Some((model_old_range, _)) = model_edits.peek() {
1030            let model_old_range = model_old_range.to_offset(old_snapshot);
1031            if model_old_range.end < user_edit.old.start {
1032                let (model_old_range, model_new_text) = model_edits.next().unwrap();
1033                edits.push((model_old_range.clone(), model_new_text.clone()));
1034            } else {
1035                break;
1036            }
1037        }
1038
1039        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1040            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1041            if user_edit.old == model_old_offset_range {
1042                let user_new_text = new_snapshot
1043                    .text_for_range(user_edit.new.clone())
1044                    .collect::<String>();
1045
1046                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1047                    if !model_suffix.is_empty() {
1048                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1049                        edits.push((anchor..anchor, model_suffix.to_string()));
1050                    }
1051
1052                    model_edits.next();
1053                    continue;
1054                }
1055            }
1056        }
1057
1058        return None;
1059    }
1060
1061    edits.extend(model_edits.cloned());
1062
1063    if edits.is_empty() { None } else { Some(edits) }
1064}
1065
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Walks a prediction through a series of user edits and checks, after each
    /// one, which edits `interpolate` still reports (as offsets in the live buffer).
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            let predicted = [(2..5, "REM".to_string()), (9..11, String::new())];
            to_prediction_edits(predicted, &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let snapshot = cx.read(|cx| buffer.read(cx).snapshot());
        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot,
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction is reported as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, String::new())])
            );

            // The user deletes the range the first edit replaces.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".to_string()), (6..8, String::new())])
            );

            // Undoing that deletion restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, String::new())])
            );

            // The user replaces the range with the first predicted character.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".to_string()), (7..9, String::new())])
            );

            // ...then enters the second predicted character...
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, String::new())])
            );

            // ...and the third, which completes the first predicted edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, String::new())])
            );

            // Removing the last entered character brings its insertion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, String::new())])
            );

            // The user performs the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string())])
            );

            // An edit that does not line up with the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and converts
    /// the surviving edits to offset ranges; `None` when interpolation fails.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let surviving = prediction.interpolate(&buffer.read(cx).snapshot())?;
        Some(from_prediction_edits(&surviving, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`, anchoring
    /// after each range's start and before its end.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, new_text) in iterator {
            let start = buffer.anchor_after(offsets.start);
            let end = buffer.anchor_before(offsets.end);
            anchored.push((start..end, new_text));
        }
        anchored
    }

    /// Resolves anchor-based edits to offset ranges in `buffer`'s current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (anchors, new_text) in editor_edits {
            let start = anchors.start.to_offset(buffer);
            let end = anchors.end.to_offset(buffer);
            resolved.push((start..end, new_text.clone()));
        }
        resolved
    }
}