zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
  10use collections::HashMap;
  11use edit_prediction_context::{
  12    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  13    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  14    SyntaxIndex, SyntaxIndexState,
  15};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::AsyncReadExt as _;
  18use futures::channel::{mpsc, oneshot};
  19use gpui::http_client::{AsyncBody, Method};
  20use gpui::{
  21    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  22    http_client, prelude::*,
  23};
  24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language::{BufferSnapshot, OffsetRangeExt};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use serde::de::DeserializeOwned;
  30use std::collections::{VecDeque, hash_map};
  31use std::fmt::Write;
  32use std::ops::Range;
  33use std::path::Path;
  34use std::str::FromStr as _;
  35use std::sync::Arc;
  36use std::time::{Duration, Instant};
  37use thiserror::Error;
  38use util::ResultExt as _;
  39use util::rel_path::RelPathBuf;
  40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  41
  42pub mod merge_excerpts;
  43mod prediction;
  44mod provider;
  45pub mod related_excerpts;
  46
  47use crate::merge_excerpts::merge_excerpts;
  48use crate::prediction::EditPrediction;
  49use crate::related_excerpts::find_related_excerpts;
  50pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
  51pub use provider::ZetaEditPredictionProvider;
  52
  53/// Maximum number of events to track.
  54const MAX_EVENT_COUNT: usize = 16;
  55
  56pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  57    max_bytes: 512,
  58    min_bytes: 128,
  59    target_before_cursor_over_total_bytes: 0.5,
  60};
  61
  62pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
  63
  64pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
  65    excerpt: DEFAULT_EXCERPT_OPTIONS,
  66};
  67
  68pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
  69    EditPredictionContextOptions {
  70        use_imports: true,
  71        max_retrieved_declarations: 0,
  72        excerpt: DEFAULT_EXCERPT_OPTIONS,
  73        score: EditPredictionScoreOptions {
  74            omit_excerpt_overlaps: true,
  75        },
  76    };
  77
  78pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  79    context: DEFAULT_CONTEXT_OPTIONS,
  80    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  81    max_diagnostic_bytes: 2048,
  82    prompt_format: PromptFormat::DEFAULT,
  83    file_indexing_parallelism: 1,
  84    buffer_change_grouping_interval: Duration::from_secs(1),
  85};
  86
  87pub struct Zeta2FeatureFlag;
  88
  89impl FeatureFlag for Zeta2FeatureFlag {
  90    const NAME: &'static str = "zeta2";
  91
  92    fn enabled_for_staff() -> bool {
  93        false
  94    }
  95}
  96
  97#[derive(Clone)]
  98struct ZetaGlobal(Entity<Zeta>);
  99
 100impl Global for ZetaGlobal {}
 101
 102pub struct Zeta {
 103    client: Arc<Client>,
 104    user_store: Entity<UserStore>,
 105    llm_token: LlmApiToken,
 106    _llm_token_subscription: Subscription,
 107    projects: HashMap<EntityId, ZetaProject>,
 108    options: ZetaOptions,
 109    update_required: bool,
 110    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
 111}
 112
 113#[derive(Debug, Clone, PartialEq)]
 114pub struct ZetaOptions {
 115    pub context: ContextMode,
 116    pub max_prompt_bytes: usize,
 117    pub max_diagnostic_bytes: usize,
 118    pub prompt_format: predict_edits_v3::PromptFormat,
 119    pub file_indexing_parallelism: usize,
 120    pub buffer_change_grouping_interval: Duration,
 121}
 122
 123#[derive(Debug, Clone, PartialEq)]
 124pub enum ContextMode {
 125    Llm(LlmContextOptions),
 126    Syntax(EditPredictionContextOptions),
 127}
 128
 129impl ContextMode {
 130    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 131        match self {
 132            ContextMode::Llm(options) => &options.excerpt,
 133            ContextMode::Syntax(options) => &options.excerpt,
 134        }
 135    }
 136}
 137
 138pub enum ZetaDebugInfo {
 139    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
 140    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
 141    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
 142    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
 143    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
 144    EditPredicted(ZetaEditPredictionDebugInfo),
 145}
 146
 147pub struct ZetaContextRetrievalStartedDebugInfo {
 148    pub project: Entity<Project>,
 149    pub timestamp: Instant,
 150    pub search_prompt: String,
 151}
 152
 153pub struct ZetaContextRetrievalDebugInfo {
 154    pub project: Entity<Project>,
 155    pub timestamp: Instant,
 156}
 157
 158pub struct ZetaEditPredictionDebugInfo {
 159    pub request: predict_edits_v3::PredictEditsRequest,
 160    pub retrieval_time: TimeDelta,
 161    pub buffer: WeakEntity<Buffer>,
 162    pub position: language::Anchor,
 163    pub local_prompt: Result<String, String>,
 164    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
 165}
 166
 167pub struct ZetaSearchQueryDebugInfo {
 168    pub project: Entity<Project>,
 169    pub timestamp: Instant,
 170    pub queries: Vec<SearchToolQuery>,
 171}
 172
 173pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 174
 175struct ZetaProject {
 176    syntax_index: Entity<SyntaxIndex>,
 177    events: VecDeque<Event>,
 178    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 179    current_prediction: Option<CurrentEditPrediction>,
 180    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
 181    refresh_context_task: Option<Task<Option<()>>>,
 182    refresh_context_debounce_task: Option<Task<Option<()>>>,
 183    refresh_context_timestamp: Option<Instant>,
 184}
 185
 186#[derive(Debug, Clone)]
 187struct CurrentEditPrediction {
 188    pub requested_by_buffer_id: EntityId,
 189    pub prediction: EditPrediction,
 190}
 191
 192impl CurrentEditPrediction {
 193    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 194        let Some(new_edits) = self
 195            .prediction
 196            .interpolate(&self.prediction.buffer.read(cx))
 197        else {
 198            return false;
 199        };
 200
 201        if self.prediction.buffer != old_prediction.prediction.buffer {
 202            return true;
 203        }
 204
 205        let Some(old_edits) = old_prediction
 206            .prediction
 207            .interpolate(&old_prediction.prediction.buffer.read(cx))
 208        else {
 209            return true;
 210        };
 211
 212        // This reduces the occurrence of UI thrash from replacing edits
 213        //
 214        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 215        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 216            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 217            && old_edits.len() == 1
 218            && new_edits.len() == 1
 219        {
 220            let (old_range, old_text) = &old_edits[0];
 221            let (new_range, new_text) = &new_edits[0];
 222            new_range == old_range && new_text.starts_with(old_text)
 223        } else {
 224            true
 225        }
 226    }
 227}
 228
 229/// A prediction from the perspective of a buffer.
 230#[derive(Debug)]
 231enum BufferEditPrediction<'a> {
 232    Local { prediction: &'a EditPrediction },
 233    Jump { prediction: &'a EditPrediction },
 234}
 235
 236struct RegisteredBuffer {
 237    snapshot: BufferSnapshot,
 238    _subscriptions: [gpui::Subscription; 2],
 239}
 240
 241#[derive(Clone)]
 242pub enum Event {
 243    BufferChange {
 244        old_snapshot: BufferSnapshot,
 245        new_snapshot: BufferSnapshot,
 246        timestamp: Instant,
 247    },
 248}
 249
 250impl Event {
 251    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 252        match self {
 253            Event::BufferChange {
 254                old_snapshot,
 255                new_snapshot,
 256                ..
 257            } => {
 258                let path = new_snapshot.file().map(|f| f.full_path(cx));
 259
 260                let old_path = old_snapshot.file().and_then(|f| {
 261                    let old_path = f.full_path(cx);
 262                    if Some(&old_path) != path.as_ref() {
 263                        Some(old_path)
 264                    } else {
 265                        None
 266                    }
 267                });
 268
 269                // TODO [zeta2] move to bg?
 270                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 271
 272                if path == old_path && diff.is_empty() {
 273                    None
 274                } else {
 275                    Some(predict_edits_v3::Event::BufferChange {
 276                        old_path,
 277                        path,
 278                        diff,
 279                        //todo: Actually detect if this edit was predicted or not
 280                        predicted: false,
 281                    })
 282                }
 283            }
 284        }
 285    }
 286}
 287
 288impl Zeta {
 289    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 290        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 291    }
 292
 293    pub fn global(
 294        client: &Arc<Client>,
 295        user_store: &Entity<UserStore>,
 296        cx: &mut App,
 297    ) -> Entity<Self> {
 298        cx.try_global::<ZetaGlobal>()
 299            .map(|global| global.0.clone())
 300            .unwrap_or_else(|| {
 301                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 302                cx.set_global(ZetaGlobal(zeta.clone()));
 303                zeta
 304            })
 305    }
 306
 307    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 308        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 309
 310        Self {
 311            projects: HashMap::default(),
 312            client,
 313            user_store,
 314            options: DEFAULT_OPTIONS,
 315            llm_token: LlmApiToken::default(),
 316            _llm_token_subscription: cx.subscribe(
 317                &refresh_llm_token_listener,
 318                |this, _listener, _event, cx| {
 319                    let client = this.client.clone();
 320                    let llm_token = this.llm_token.clone();
 321                    cx.spawn(async move |_this, _cx| {
 322                        llm_token.refresh(&client).await?;
 323                        anyhow::Ok(())
 324                    })
 325                    .detach_and_log_err(cx);
 326                },
 327            ),
 328            update_required: false,
 329            debug_tx: None,
 330        }
 331    }
 332
 333    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 334        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 335        self.debug_tx = Some(debug_watch_tx);
 336        debug_watch_rx
 337    }
 338
 339    pub fn options(&self) -> &ZetaOptions {
 340        &self.options
 341    }
 342
 343    pub fn set_options(&mut self, options: ZetaOptions) {
 344        self.options = options;
 345    }
 346
 347    pub fn clear_history(&mut self) {
 348        for zeta_project in self.projects.values_mut() {
 349            zeta_project.events.clear();
 350        }
 351    }
 352
 353    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 354        self.projects
 355            .get(&project.entity_id())
 356            .map(|project| project.events.iter())
 357            .into_iter()
 358            .flatten()
 359    }
 360
 361    pub fn context_for_project(
 362        &self,
 363        project: &Entity<Project>,
 364    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 365        self.projects
 366            .get(&project.entity_id())
 367            .and_then(|project| {
 368                Some(
 369                    project
 370                        .context
 371                        .as_ref()?
 372                        .iter()
 373                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 374                )
 375            })
 376            .into_iter()
 377            .flatten()
 378    }
 379
 380    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 381        self.user_store.read(cx).edit_prediction_usage()
 382    }
 383
 384    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 385        self.get_or_init_zeta_project(project, cx);
 386    }
 387
 388    pub fn register_buffer(
 389        &mut self,
 390        buffer: &Entity<Buffer>,
 391        project: &Entity<Project>,
 392        cx: &mut Context<Self>,
 393    ) {
 394        let zeta_project = self.get_or_init_zeta_project(project, cx);
 395        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 396    }
 397
 398    fn get_or_init_zeta_project(
 399        &mut self,
 400        project: &Entity<Project>,
 401        cx: &mut App,
 402    ) -> &mut ZetaProject {
 403        self.projects
 404            .entry(project.entity_id())
 405            .or_insert_with(|| ZetaProject {
 406                syntax_index: cx.new(|cx| {
 407                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 408                }),
 409                events: VecDeque::new(),
 410                registered_buffers: HashMap::default(),
 411                current_prediction: None,
 412                context: None,
 413                refresh_context_task: None,
 414                refresh_context_debounce_task: None,
 415                refresh_context_timestamp: None,
 416            })
 417    }
 418
 419    fn register_buffer_impl<'a>(
 420        zeta_project: &'a mut ZetaProject,
 421        buffer: &Entity<Buffer>,
 422        project: &Entity<Project>,
 423        cx: &mut Context<Self>,
 424    ) -> &'a mut RegisteredBuffer {
 425        let buffer_id = buffer.entity_id();
 426        match zeta_project.registered_buffers.entry(buffer_id) {
 427            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 428            hash_map::Entry::Vacant(entry) => {
 429                let snapshot = buffer.read(cx).snapshot();
 430                let project_entity_id = project.entity_id();
 431                entry.insert(RegisteredBuffer {
 432                    snapshot,
 433                    _subscriptions: [
 434                        cx.subscribe(buffer, {
 435                            let project = project.downgrade();
 436                            move |this, buffer, event, cx| {
 437                                if let language::BufferEvent::Edited = event
 438                                    && let Some(project) = project.upgrade()
 439                                {
 440                                    this.report_changes_for_buffer(&buffer, &project, cx);
 441                                }
 442                            }
 443                        }),
 444                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 445                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 446                            else {
 447                                return;
 448                            };
 449                            zeta_project.registered_buffers.remove(&buffer_id);
 450                        }),
 451                    ],
 452                })
 453            }
 454        }
 455    }
 456
 457    fn report_changes_for_buffer(
 458        &mut self,
 459        buffer: &Entity<Buffer>,
 460        project: &Entity<Project>,
 461        cx: &mut Context<Self>,
 462    ) -> BufferSnapshot {
 463        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
 464        let zeta_project = self.get_or_init_zeta_project(project, cx);
 465        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 466
 467        let new_snapshot = buffer.read(cx).snapshot();
 468        if new_snapshot.version != registered_buffer.snapshot.version {
 469            let old_snapshot =
 470                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 471            Self::push_event(
 472                zeta_project,
 473                buffer_change_grouping_interval,
 474                Event::BufferChange {
 475                    old_snapshot,
 476                    new_snapshot: new_snapshot.clone(),
 477                    timestamp: Instant::now(),
 478                },
 479            );
 480        }
 481
 482        new_snapshot
 483    }
 484
 485    fn push_event(
 486        zeta_project: &mut ZetaProject,
 487        buffer_change_grouping_interval: Duration,
 488        event: Event,
 489    ) {
 490        let events = &mut zeta_project.events;
 491
 492        if buffer_change_grouping_interval > Duration::ZERO
 493            && let Some(Event::BufferChange {
 494                new_snapshot: last_new_snapshot,
 495                timestamp: last_timestamp,
 496                ..
 497            }) = events.back_mut()
 498        {
 499            // Coalesce edits for the same buffer when they happen one after the other.
 500            let Event::BufferChange {
 501                old_snapshot,
 502                new_snapshot,
 503                timestamp,
 504            } = &event;
 505
 506            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
 507                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 508                && old_snapshot.version == last_new_snapshot.version
 509            {
 510                *last_new_snapshot = new_snapshot.clone();
 511                *last_timestamp = *timestamp;
 512                return;
 513            }
 514        }
 515
 516        if events.len() >= MAX_EVENT_COUNT {
 517            // These are halved instead of popping to improve prompt caching.
 518            events.drain(..MAX_EVENT_COUNT / 2);
 519        }
 520
 521        events.push_back(event);
 522    }
 523
 524    fn current_prediction_for_buffer(
 525        &self,
 526        buffer: &Entity<Buffer>,
 527        project: &Entity<Project>,
 528        cx: &App,
 529    ) -> Option<BufferEditPrediction<'_>> {
 530        let project_state = self.projects.get(&project.entity_id())?;
 531
 532        let CurrentEditPrediction {
 533            requested_by_buffer_id,
 534            prediction,
 535        } = project_state.current_prediction.as_ref()?;
 536
 537        if prediction.targets_buffer(buffer.read(cx), cx) {
 538            Some(BufferEditPrediction::Local { prediction })
 539        } else if *requested_by_buffer_id == buffer.entity_id() {
 540            Some(BufferEditPrediction::Jump { prediction })
 541        } else {
 542            None
 543        }
 544    }
 545
 546    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 547        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 548            return;
 549        };
 550
 551        let Some(prediction) = project_state.current_prediction.take() else {
 552            return;
 553        };
 554        let request_id = prediction.prediction.id.into();
 555
 556        let client = self.client.clone();
 557        let llm_token = self.llm_token.clone();
 558        let app_version = AppVersion::global(cx);
 559        cx.spawn(async move |this, cx| {
 560            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 561                http_client::Url::parse(&predict_edits_url)?
 562            } else {
 563                client
 564                    .http_client()
 565                    .build_zed_llm_url("/predict_edits/accept", &[])?
 566            };
 567
 568            let response = cx
 569                .background_spawn(Self::send_api_request::<()>(
 570                    move |builder| {
 571                        let req = builder.uri(url.as_ref()).body(
 572                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 573                        );
 574                        Ok(req?)
 575                    },
 576                    client,
 577                    llm_token,
 578                    app_version,
 579                ))
 580                .await;
 581
 582            Self::handle_api_response(&this, response, cx)?;
 583            anyhow::Ok(())
 584        })
 585        .detach_and_log_err(cx);
 586    }
 587
 588    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 589        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 590            project_state.current_prediction.take();
 591        };
 592    }
 593
 594    pub fn refresh_prediction(
 595        &mut self,
 596        project: &Entity<Project>,
 597        buffer: &Entity<Buffer>,
 598        position: language::Anchor,
 599        cx: &mut Context<Self>,
 600    ) -> Task<Result<()>> {
 601        let request_task = self.request_prediction(project, buffer, position, cx);
 602        let buffer = buffer.clone();
 603        let project = project.clone();
 604
 605        cx.spawn(async move |this, cx| {
 606            if let Some(prediction) = request_task.await? {
 607                this.update(cx, |this, cx| {
 608                    let project_state = this
 609                        .projects
 610                        .get_mut(&project.entity_id())
 611                        .context("Project not found")?;
 612
 613                    let new_prediction = CurrentEditPrediction {
 614                        requested_by_buffer_id: buffer.entity_id(),
 615                        prediction: prediction,
 616                    };
 617
 618                    if project_state
 619                        .current_prediction
 620                        .as_ref()
 621                        .is_none_or(|old_prediction| {
 622                            new_prediction.should_replace_prediction(&old_prediction, cx)
 623                        })
 624                    {
 625                        project_state.current_prediction = Some(new_prediction);
 626                    }
 627                    anyhow::Ok(())
 628                })??;
 629            }
 630            Ok(())
 631        })
 632    }
 633
 634    pub fn request_prediction(
 635        &mut self,
 636        project: &Entity<Project>,
 637        buffer: &Entity<Buffer>,
 638        position: language::Anchor,
 639        cx: &mut Context<Self>,
 640    ) -> Task<Result<Option<EditPrediction>>> {
 641        let project_state = self.projects.get(&project.entity_id());
 642
 643        let index_state = project_state.map(|state| {
 644            state
 645                .syntax_index
 646                .read_with(cx, |index, _cx| index.state().clone())
 647        });
 648        let options = self.options.clone();
 649        let snapshot = buffer.read(cx).snapshot();
 650        let Some(excerpt_path) = snapshot
 651            .file()
 652            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 653        else {
 654            return Task::ready(Err(anyhow!("No file path for excerpt")));
 655        };
 656        let client = self.client.clone();
 657        let llm_token = self.llm_token.clone();
 658        let app_version = AppVersion::global(cx);
 659        let worktree_snapshots = project
 660            .read(cx)
 661            .worktrees(cx)
 662            .map(|worktree| worktree.read(cx).snapshot())
 663            .collect::<Vec<_>>();
 664        let debug_tx = self.debug_tx.clone();
 665
 666        let events = project_state
 667            .map(|state| {
 668                state
 669                    .events
 670                    .iter()
 671                    .filter_map(|event| event.to_request_event(cx))
 672                    .collect::<Vec<_>>()
 673            })
 674            .unwrap_or_default();
 675
 676        let diagnostics = snapshot.diagnostic_sets().clone();
 677
 678        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 679            let mut path = f.worktree.read(cx).absolutize(&f.path);
 680            if path.pop() { Some(path) } else { None }
 681        });
 682
 683        // TODO data collection
 684        let can_collect_data = cx.is_staff();
 685
 686        let mut included_files = project_state
 687            .and_then(|project_state| project_state.context.as_ref())
 688            .unwrap_or(&HashMap::default())
 689            .iter()
 690            .filter_map(|(buffer, ranges)| {
 691                let buffer = buffer.read(cx);
 692                Some((
 693                    buffer.snapshot(),
 694                    buffer.file()?.full_path(cx).into(),
 695                    ranges.clone(),
 696                ))
 697            })
 698            .collect::<Vec<_>>();
 699
 700        let request_task = cx.background_spawn({
 701            let snapshot = snapshot.clone();
 702            let buffer = buffer.clone();
 703            async move {
 704                let index_state = if let Some(index_state) = index_state {
 705                    Some(index_state.lock_owned().await)
 706                } else {
 707                    None
 708                };
 709
 710                let cursor_offset = position.to_offset(&snapshot);
 711                let cursor_point = cursor_offset.to_point(&snapshot);
 712
 713                let before_retrieval = chrono::Utc::now();
 714
 715                let (diagnostic_groups, diagnostic_groups_truncated) =
 716                    Self::gather_nearby_diagnostics(
 717                        cursor_offset,
 718                        &diagnostics,
 719                        &snapshot,
 720                        options.max_diagnostic_bytes,
 721                    );
 722
 723                let request = match options.context {
 724                    ContextMode::Llm(context_options) => {
 725                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 726                            cursor_point,
 727                            &snapshot,
 728                            &context_options.excerpt,
 729                            index_state.as_deref(),
 730                        ) else {
 731                            return Ok((None, None));
 732                        };
 733
 734                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
 735                            ..snapshot.anchor_before(excerpt.range.end);
 736
 737                        if let Some(buffer_ix) = included_files
 738                            .iter()
 739                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
 740                        {
 741                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
 742                            let range_ix = ranges
 743                                .binary_search_by(|probe| {
 744                                    probe
 745                                        .start
 746                                        .cmp(&excerpt_anchor_range.start, buffer)
 747                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 748                                })
 749                                .unwrap_or_else(|ix| ix);
 750
 751                            ranges.insert(range_ix, excerpt_anchor_range);
 752                            let last_ix = included_files.len() - 1;
 753                            included_files.swap(buffer_ix, last_ix);
 754                        } else {
 755                            included_files.push((
 756                                snapshot,
 757                                excerpt_path.clone(),
 758                                vec![excerpt_anchor_range],
 759                            ));
 760                        }
 761
 762                        let included_files = included_files
 763                            .into_iter()
 764                            .map(|(buffer, path, ranges)| {
 765                                let excerpts = merge_excerpts(
 766                                    &buffer,
 767                                    ranges.iter().map(|range| {
 768                                        let point_range = range.to_point(&buffer);
 769                                        Line(point_range.start.row)..Line(point_range.end.row)
 770                                    }),
 771                                );
 772                                predict_edits_v3::IncludedFile {
 773                                    path,
 774                                    max_row: Line(buffer.max_point().row),
 775                                    excerpts,
 776                                }
 777                            })
 778                            .collect::<Vec<_>>();
 779
 780                        predict_edits_v3::PredictEditsRequest {
 781                            excerpt_path,
 782                            excerpt: String::new(),
 783                            excerpt_line_range: Line(0)..Line(0),
 784                            excerpt_range: 0..0,
 785                            cursor_point: predict_edits_v3::Point {
 786                                line: predict_edits_v3::Line(cursor_point.row),
 787                                column: cursor_point.column,
 788                            },
 789                            included_files,
 790                            referenced_declarations: vec![],
 791                            events,
 792                            can_collect_data,
 793                            diagnostic_groups,
 794                            diagnostic_groups_truncated,
 795                            debug_info: debug_tx.is_some(),
 796                            prompt_max_bytes: Some(options.max_prompt_bytes),
 797                            prompt_format: options.prompt_format,
 798                            // TODO [zeta2]
 799                            signatures: vec![],
 800                            excerpt_parent: None,
 801                            git_info: None,
 802                        }
 803                    }
 804                    ContextMode::Syntax(context_options) => {
 805                        let Some(context) = EditPredictionContext::gather_context(
 806                            cursor_point,
 807                            &snapshot,
 808                            parent_abs_path.as_deref(),
 809                            &context_options,
 810                            index_state.as_deref(),
 811                        ) else {
 812                            return Ok((None, None));
 813                        };
 814
 815                        make_syntax_context_cloud_request(
 816                            excerpt_path,
 817                            context,
 818                            events,
 819                            can_collect_data,
 820                            diagnostic_groups,
 821                            diagnostic_groups_truncated,
 822                            None,
 823                            debug_tx.is_some(),
 824                            &worktree_snapshots,
 825                            index_state.as_deref(),
 826                            Some(options.max_prompt_bytes),
 827                            options.prompt_format,
 828                        )
 829                    }
 830                };
 831
 832                let retrieval_time = chrono::Utc::now() - before_retrieval;
 833
 834                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 835                    let (response_tx, response_rx) = oneshot::channel();
 836
 837                    let local_prompt = build_prompt(&request)
 838                        .map(|(prompt, _)| prompt)
 839                        .map_err(|err| err.to_string());
 840
 841                    debug_tx
 842                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
 843                            request: request.clone(),
 844                            retrieval_time,
 845                            buffer: buffer.downgrade(),
 846                            local_prompt,
 847                            position,
 848                            response_rx,
 849                        }))
 850                        .ok();
 851                    Some(response_tx)
 852                } else {
 853                    None
 854                };
 855
 856                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 857                    if let Some(debug_response_tx) = debug_response_tx {
 858                        debug_response_tx
 859                            .send(Err("Request skipped".to_string()))
 860                            .ok();
 861                    }
 862                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 863                }
 864
 865                let response =
 866                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 867
 868                if let Some(debug_response_tx) = debug_response_tx {
 869                    debug_response_tx
 870                        .send(
 871                            response
 872                                .as_ref()
 873                                .map_err(|err| err.to_string())
 874                                .map(|response| response.0.clone()),
 875                        )
 876                        .ok();
 877                }
 878
 879                response.map(|(res, usage)| (Some(res), usage))
 880            }
 881        });
 882
 883        let buffer = buffer.clone();
 884
 885        cx.spawn({
 886            let project = project.clone();
 887            async move |this, cx| {
 888                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 889                else {
 890                    return Ok(None);
 891                };
 892
 893                // TODO telemetry: duration, etc
 894                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 895            }
 896        })
 897    }
 898
 899    async fn send_prediction_request(
 900        client: Arc<Client>,
 901        llm_token: LlmApiToken,
 902        app_version: SemanticVersion,
 903        request: predict_edits_v3::PredictEditsRequest,
 904    ) -> Result<(
 905        predict_edits_v3::PredictEditsResponse,
 906        Option<EditPredictionUsage>,
 907    )> {
 908        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 909            http_client::Url::parse(&predict_edits_url)?
 910        } else {
 911            client
 912                .http_client()
 913                .build_zed_llm_url("/predict_edits/v3", &[])?
 914        };
 915
 916        Self::send_api_request(
 917            |builder| {
 918                let req = builder
 919                    .uri(url.as_ref())
 920                    .body(serde_json::to_string(&request)?.into());
 921                Ok(req?)
 922            },
 923            client,
 924            llm_token,
 925            app_version,
 926        )
 927        .await
 928    }
 929
 930    fn handle_api_response<T>(
 931        this: &WeakEntity<Self>,
 932        response: Result<(T, Option<EditPredictionUsage>)>,
 933        cx: &mut gpui::AsyncApp,
 934    ) -> Result<T> {
 935        match response {
 936            Ok((data, usage)) => {
 937                if let Some(usage) = usage {
 938                    this.update(cx, |this, cx| {
 939                        this.user_store.update(cx, |user_store, cx| {
 940                            user_store.update_edit_prediction_usage(usage, cx);
 941                        });
 942                    })
 943                    .ok();
 944                }
 945                Ok(data)
 946            }
 947            Err(err) => {
 948                if err.is::<ZedUpdateRequiredError>() {
 949                    cx.update(|cx| {
 950                        this.update(cx, |this, _cx| {
 951                            this.update_required = true;
 952                        })
 953                        .ok();
 954
 955                        let error_message: SharedString = err.to_string().into();
 956                        show_app_notification(
 957                            NotificationId::unique::<ZedUpdateRequiredError>(),
 958                            cx,
 959                            move |cx| {
 960                                cx.new(|cx| {
 961                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 962                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 963                                })
 964                            },
 965                        );
 966                    })
 967                    .ok();
 968                }
 969                Err(err)
 970            }
 971        }
 972    }
 973
 974    async fn send_api_request<Res>(
 975        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 976        client: Arc<Client>,
 977        llm_token: LlmApiToken,
 978        app_version: SemanticVersion,
 979    ) -> Result<(Res, Option<EditPredictionUsage>)>
 980    where
 981        Res: DeserializeOwned,
 982    {
 983        let http_client = client.http_client();
 984        let mut token = llm_token.acquire(&client).await?;
 985        let mut did_retry = false;
 986
 987        loop {
 988            let request_builder = http_client::Request::builder().method(Method::POST);
 989
 990            let request = build(
 991                request_builder
 992                    .header("Content-Type", "application/json")
 993                    .header("Authorization", format!("Bearer {}", token))
 994                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 995            )?;
 996
 997            let mut response = http_client.send(request).await?;
 998
 999            if let Some(minimum_required_version) = response
1000                .headers()
1001                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1002                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1003            {
1004                anyhow::ensure!(
1005                    app_version >= minimum_required_version,
1006                    ZedUpdateRequiredError {
1007                        minimum_version: minimum_required_version
1008                    }
1009                );
1010            }
1011
1012            if response.status().is_success() {
1013                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1014
1015                let mut body = Vec::new();
1016                response.body_mut().read_to_end(&mut body).await?;
1017                return Ok((serde_json::from_slice(&body)?, usage));
1018            } else if !did_retry
1019                && response
1020                    .headers()
1021                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1022                    .is_some()
1023            {
1024                did_retry = true;
1025                token = llm_token.refresh(&client).await?;
1026            } else {
1027                let mut body = String::new();
1028                response.body_mut().read_to_string(&mut body).await?;
1029                anyhow::bail!(
1030                    "Request failed with status: {:?}\nBody: {}",
1031                    response.status(),
1032                    body
1033                );
1034            }
1035        }
1036    }
1037
    /// When more than this much time has elapsed since the previous call to
    /// `refresh_context_if_needed`, the context is refreshed without waiting
    /// for the debounce timer.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// How long `refresh_context_if_needed` waits before refreshing the context
    /// when the previous call was not considered idle.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1040
1041    // Refresh the related excerpts when the user just beguns editing after
1042    // an idle period, and after they pause editing.
1043    fn refresh_context_if_needed(
1044        &mut self,
1045        project: &Entity<Project>,
1046        buffer: &Entity<language::Buffer>,
1047        cursor_position: language::Anchor,
1048        cx: &mut Context<Self>,
1049    ) {
1050        if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1051            return;
1052        }
1053
1054        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1055            return;
1056        };
1057
1058        let now = Instant::now();
1059        let was_idle = zeta_project
1060            .refresh_context_timestamp
1061            .map_or(true, |timestamp| {
1062                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1063            });
1064        zeta_project.refresh_context_timestamp = Some(now);
1065        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1066            let buffer = buffer.clone();
1067            let project = project.clone();
1068            async move |this, cx| {
1069                if was_idle {
1070                    log::debug!("refetching edit prediction context after idle");
1071                } else {
1072                    cx.background_executor()
1073                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1074                        .await;
1075                    log::debug!("refetching edit prediction context after pause");
1076                }
1077                this.update(cx, |this, cx| {
1078                    this.refresh_context(project, buffer, cursor_position, cx);
1079                })
1080                .ok()
1081            }
1082        }));
1083    }
1084
1085    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1086    // and avoid spawning more than one concurrent task.
1087    fn refresh_context(
1088        &mut self,
1089        project: Entity<Project>,
1090        buffer: Entity<language::Buffer>,
1091        cursor_position: language::Anchor,
1092        cx: &mut Context<Self>,
1093    ) {
1094        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1095            return;
1096        };
1097
1098        let debug_tx = self.debug_tx.clone();
1099
1100        zeta_project
1101            .refresh_context_task
1102            .get_or_insert(cx.spawn(async move |this, cx| {
1103                let related_excerpts = this
1104                    .update(cx, |this, cx| {
1105                        let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1106                            return Task::ready(anyhow::Ok(HashMap::default()));
1107                        };
1108
1109                        let ContextMode::Llm(options) = &this.options().context else {
1110                            return Task::ready(anyhow::Ok(HashMap::default()));
1111                        };
1112
1113                        let mut edit_history_unified_diff = String::new();
1114
1115                        for event in zeta_project.events.iter() {
1116                            if let Some(event) = event.to_request_event(cx) {
1117                                writeln!(&mut edit_history_unified_diff, "{event}").ok();
1118                            }
1119                        }
1120
1121                        find_related_excerpts(
1122                            buffer.clone(),
1123                            cursor_position,
1124                            &project,
1125                            edit_history_unified_diff,
1126                            options,
1127                            debug_tx,
1128                            cx,
1129                        )
1130                    })
1131                    .ok()?
1132                    .await
1133                    .log_err()
1134                    .unwrap_or_default();
1135                this.update(cx, |this, _cx| {
1136                    let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1137                        return;
1138                    };
1139                    zeta_project.context = Some(related_excerpts);
1140                    zeta_project.refresh_context_task.take();
1141                    if let Some(debug_tx) = &this.debug_tx {
1142                        debug_tx
1143                            .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1144                                ZetaContextRetrievalDebugInfo {
1145                                    project,
1146                                    timestamp: Instant::now(),
1147                                },
1148                            ))
1149                            .ok();
1150                    }
1151                })
1152                .ok()
1153            }));
1154    }
1155
1156    fn gather_nearby_diagnostics(
1157        cursor_offset: usize,
1158        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1159        snapshot: &BufferSnapshot,
1160        max_diagnostics_bytes: usize,
1161    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1162        // TODO: Could make this more efficient
1163        let mut diagnostic_groups = Vec::new();
1164        for (language_server_id, diagnostics) in diagnostic_sets {
1165            let mut groups = Vec::new();
1166            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1167            diagnostic_groups.extend(
1168                groups
1169                    .into_iter()
1170                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1171            );
1172        }
1173
1174        // sort by proximity to cursor
1175        diagnostic_groups.sort_by_key(|group| {
1176            let range = &group.entries[group.primary_ix].range;
1177            if range.start >= cursor_offset {
1178                range.start - cursor_offset
1179            } else if cursor_offset >= range.end {
1180                cursor_offset - range.end
1181            } else {
1182                (cursor_offset - range.start).min(range.end - cursor_offset)
1183            }
1184        });
1185
1186        let mut results = Vec::new();
1187        let mut diagnostic_groups_truncated = false;
1188        let mut diagnostics_byte_count = 0;
1189        for group in diagnostic_groups {
1190            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1191            diagnostics_byte_count += raw_value.get().len();
1192            if diagnostics_byte_count > max_diagnostics_bytes {
1193                diagnostic_groups_truncated = true;
1194                break;
1195            }
1196            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1197        }
1198
1199        (results, diagnostic_groups_truncated)
1200    }
1201
1202    // TODO: Dedupe with similar code in request_prediction?
1203    pub fn cloud_request_for_zeta_cli(
1204        &mut self,
1205        project: &Entity<Project>,
1206        buffer: &Entity<Buffer>,
1207        position: language::Anchor,
1208        cx: &mut Context<Self>,
1209    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1210        let project_state = self.projects.get(&project.entity_id());
1211
1212        let index_state = project_state.map(|state| {
1213            state
1214                .syntax_index
1215                .read_with(cx, |index, _cx| index.state().clone())
1216        });
1217        let options = self.options.clone();
1218        let snapshot = buffer.read(cx).snapshot();
1219        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1220            return Task::ready(Err(anyhow!("No file path for excerpt")));
1221        };
1222        let worktree_snapshots = project
1223            .read(cx)
1224            .worktrees(cx)
1225            .map(|worktree| worktree.read(cx).snapshot())
1226            .collect::<Vec<_>>();
1227
1228        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1229            let mut path = f.worktree.read(cx).absolutize(&f.path);
1230            if path.pop() { Some(path) } else { None }
1231        });
1232
1233        cx.background_spawn(async move {
1234            let index_state = if let Some(index_state) = index_state {
1235                Some(index_state.lock_owned().await)
1236            } else {
1237                None
1238            };
1239
1240            let cursor_point = position.to_point(&snapshot);
1241
1242            let debug_info = true;
1243            EditPredictionContext::gather_context(
1244                cursor_point,
1245                &snapshot,
1246                parent_abs_path.as_deref(),
1247                match &options.context {
1248                    ContextMode::Llm(_) => {
1249                        // TODO
1250                        panic!("Llm mode not supported in zeta cli yet");
1251                    }
1252                    ContextMode::Syntax(edit_prediction_context_options) => {
1253                        edit_prediction_context_options
1254                    }
1255                },
1256                index_state.as_deref(),
1257            )
1258            .context("Failed to select excerpt")
1259            .map(|context| {
1260                make_syntax_context_cloud_request(
1261                    excerpt_path.into(),
1262                    context,
1263                    // TODO pass everything
1264                    Vec::new(),
1265                    false,
1266                    Vec::new(),
1267                    false,
1268                    None,
1269                    debug_info,
1270                    &worktree_snapshots,
1271                    index_state.as_deref(),
1272                    Some(options.max_prompt_bytes),
1273                    options.prompt_format,
1274                )
1275            })
1276        })
1277    }
1278
1279    pub fn wait_for_initial_indexing(
1280        &mut self,
1281        project: &Entity<Project>,
1282        cx: &mut App,
1283    ) -> Task<Result<()>> {
1284        let zeta_project = self.get_or_init_zeta_project(project, cx);
1285        zeta_project
1286            .syntax_index
1287            .read(cx)
1288            .wait_for_initial_file_indexing(cx)
1289    }
1290}
1291
/// Error returned by `send_api_request` when a response carries a minimum
/// required version header naming a version newer than the running app.
///
/// `handle_api_response` detects this error to set `update_required` and show
/// an update notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version parsed from the response header; interpolated into the
    /// error message.
    minimum_version: SemanticVersion,
}
1299
1300fn make_syntax_context_cloud_request(
1301    excerpt_path: Arc<Path>,
1302    context: EditPredictionContext,
1303    events: Vec<predict_edits_v3::Event>,
1304    can_collect_data: bool,
1305    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1306    diagnostic_groups_truncated: bool,
1307    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1308    debug_info: bool,
1309    worktrees: &Vec<worktree::Snapshot>,
1310    index_state: Option<&SyntaxIndexState>,
1311    prompt_max_bytes: Option<usize>,
1312    prompt_format: PromptFormat,
1313) -> predict_edits_v3::PredictEditsRequest {
1314    let mut signatures = Vec::new();
1315    let mut declaration_to_signature_index = HashMap::default();
1316    let mut referenced_declarations = Vec::new();
1317
1318    for snippet in context.declarations {
1319        let project_entry_id = snippet.declaration.project_entry_id();
1320        let Some(path) = worktrees.iter().find_map(|worktree| {
1321            worktree.entry_for_id(project_entry_id).map(|entry| {
1322                let mut full_path = RelPathBuf::new();
1323                full_path.push(worktree.root_name());
1324                full_path.push(&entry.path);
1325                full_path
1326            })
1327        }) else {
1328            continue;
1329        };
1330
1331        let parent_index = index_state.and_then(|index_state| {
1332            snippet.declaration.parent().and_then(|parent| {
1333                add_signature(
1334                    parent,
1335                    &mut declaration_to_signature_index,
1336                    &mut signatures,
1337                    index_state,
1338                )
1339            })
1340        });
1341
1342        let (text, text_is_truncated) = snippet.declaration.item_text();
1343        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1344            path: path.as_std_path().into(),
1345            text: text.into(),
1346            range: snippet.declaration.item_line_range(),
1347            text_is_truncated,
1348            signature_range: snippet.declaration.signature_range_in_item_text(),
1349            parent_index,
1350            signature_score: snippet.score(DeclarationStyle::Signature),
1351            declaration_score: snippet.score(DeclarationStyle::Declaration),
1352            score_components: snippet.components,
1353        });
1354    }
1355
1356    let excerpt_parent = index_state.and_then(|index_state| {
1357        context
1358            .excerpt
1359            .parent_declarations
1360            .last()
1361            .and_then(|(parent, _)| {
1362                add_signature(
1363                    *parent,
1364                    &mut declaration_to_signature_index,
1365                    &mut signatures,
1366                    index_state,
1367                )
1368            })
1369    });
1370
1371    predict_edits_v3::PredictEditsRequest {
1372        excerpt_path,
1373        excerpt: context.excerpt_text.body,
1374        excerpt_line_range: context.excerpt.line_range,
1375        excerpt_range: context.excerpt.range,
1376        cursor_point: predict_edits_v3::Point {
1377            line: predict_edits_v3::Line(context.cursor_point.row),
1378            column: context.cursor_point.column,
1379        },
1380        referenced_declarations,
1381        included_files: vec![],
1382        signatures,
1383        excerpt_parent,
1384        events,
1385        can_collect_data,
1386        diagnostic_groups,
1387        diagnostic_groups_truncated,
1388        git_info,
1389        debug_info,
1390        prompt_max_bytes,
1391        prompt_format,
1392    }
1393}
1394
1395fn add_signature(
1396    declaration_id: DeclarationId,
1397    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1398    signatures: &mut Vec<Signature>,
1399    index: &SyntaxIndexState,
1400) -> Option<usize> {
1401    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1402        return Some(*signature_index);
1403    }
1404    let Some(parent_declaration) = index.declaration(declaration_id) else {
1405        log::error!("bug: missing parent declaration");
1406        return None;
1407    };
1408    let parent_index = parent_declaration.parent().and_then(|parent| {
1409        add_signature(parent, declaration_to_signature_index, signatures, index)
1410    });
1411    let (text, text_is_truncated) = parent_declaration.signature_text();
1412    let signature_index = signatures.len();
1413    signatures.push(Signature {
1414        text: text.into(),
1415        text_is_truncated,
1416        parent_index,
1417        range: parent_declaration.signature_line_range(),
1418    });
1419    declaration_to_signature_index.insert(declaration_id, signature_index);
1420    Some(signature_index)
1421}
1422
1423#[cfg(test)]
1424mod tests {
1425    use std::{
1426        path::{Path, PathBuf},
1427        sync::Arc,
1428    };
1429
1430    use client::UserStore;
1431    use clock::FakeSystemClock;
1432    use cloud_llm_client::predict_edits_v3::{self, Point};
1433    use edit_prediction_context::Line;
1434    use futures::{
1435        AsyncReadExt, StreamExt,
1436        channel::{mpsc, oneshot},
1437    };
1438    use gpui::{
1439        Entity, TestAppContext,
1440        http_client::{FakeHttpClient, Response},
1441        prelude::*,
1442    };
1443    use indoc::indoc;
1444    use language::{LanguageServerId, OffsetRangeExt as _};
1445    use pretty_assertions::{assert_eq, assert_matches};
1446    use project::{FakeFs, Project};
1447    use serde_json::json;
1448    use settings::SettingsStore;
1449    use util::path;
1450    use uuid::Uuid;
1451
1452    use crate::{BufferEditPrediction, Zeta};
1453
    // Checks which kind of `BufferEditPrediction` is reported for a buffer:
    // `Local` when the current prediction edits that buffer's file, and `Jump`
    // when it edits a different file.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        // `req_rx` yields each outgoing request paired with a sender through which
        // the test supplies the response (presumably wired up by `init_test`).
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor on the second line ("How"), at column 3.
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Respond with an edit spanning every line of 1.txt.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        // The request is still made from buffer1, but the response edits 2.txt.
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // The range reuses snapshot1's row count; both fixture files have three
        // lines, so it also spans all of 2.txt.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        // Seen from buffer1, the prediction is a jump to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Seen from the buffer for 2.txt, the same prediction is local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1550
    // Checks that a prediction request carries the buffer's path and cursor
    // position, and that a whole-file replacement in the response is reduced to
    // the single edit that actually differs from the buffer.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        // `req_rx` yields each outgoing request paired with a sender through which
        // the test supplies the response (presumably wired up by `init_test`).
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor on the second line ("How"), at column 3.
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        // The request path includes the worktree root name.
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        // Respond with a replacement for every line of the file in which only the
        // second line changes.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // Only the inserted text remains, starting at the original cursor position.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }
1612
1613    #[gpui::test]
1614    async fn test_request_events(cx: &mut TestAppContext) {
1615        let (zeta, mut req_rx) = init_test(cx);
1616        let fs = FakeFs::new(cx.executor());
1617        fs.insert_tree(
1618            "/root",
1619            json!({
1620                "foo.md": "Hello!\n\nBye"
1621            }),
1622        )
1623        .await;
1624        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1625
1626        let buffer = project
1627            .update(cx, |project, cx| {
1628                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1629                project.open_buffer(path, cx)
1630            })
1631            .await
1632            .unwrap();
1633
1634        zeta.update(cx, |zeta, cx| {
1635            zeta.register_buffer(&buffer, &project, cx);
1636        });
1637
1638        buffer.update(cx, |buffer, cx| {
1639            buffer.edit(vec![(7..7, "How")], None, cx);
1640        });
1641
1642        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1643        let position = snapshot.anchor_before(language::Point::new(1, 3));
1644
1645        let prediction_task = zeta.update(cx, |zeta, cx| {
1646            zeta.request_prediction(&project, &buffer, position, cx)
1647        });
1648
1649        let (request, respond_tx) = req_rx.next().await.unwrap();
1650
1651        assert_eq!(request.events.len(), 1);
1652        assert_eq!(
1653            request.events[0],
1654            predict_edits_v3::Event::BufferChange {
1655                path: Some(PathBuf::from(path!("root/foo.md"))),
1656                old_path: None,
1657                diff: indoc! {"
1658                        @@ -1,3 +1,3 @@
1659                         Hello!
1660                        -
1661                        +How
1662                         Bye
1663                    "}
1664                .to_string(),
1665                predicted: false
1666            }
1667        );
1668
1669        respond_tx
1670            .send(predict_edits_v3::PredictEditsResponse {
1671                request_id: Uuid::new_v4(),
1672                edits: vec![predict_edits_v3::Edit {
1673                    path: Path::new(path!("root/foo.md")).into(),
1674                    range: Line(0)..Line(snapshot.max_point().row + 1),
1675                    content: "Hello!\nHow are you?\nBye".into(),
1676                }],
1677                debug_info: None,
1678            })
1679            .unwrap();
1680
1681        let prediction = prediction_task.await.unwrap().unwrap();
1682
1683        assert_eq!(prediction.edits.len(), 1);
1684        assert_eq!(
1685            prediction.edits[0].0.to_point(&snapshot).start,
1686            language::Point::new(1, 3)
1687        );
1688        assert_eq!(prediction.edits[0].1, " are you?");
1689    }
1690
1691    #[gpui::test]
1692    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1693        let (zeta, mut req_rx) = init_test(cx);
1694        let fs = FakeFs::new(cx.executor());
1695        fs.insert_tree(
1696            "/root",
1697            json!({
1698                "foo.md": "Hello!\nBye"
1699            }),
1700        )
1701        .await;
1702        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1703
1704        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1705        let diagnostic = lsp::Diagnostic {
1706            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1707            severity: Some(lsp::DiagnosticSeverity::ERROR),
1708            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1709            ..Default::default()
1710        };
1711
1712        project.update(cx, |project, cx| {
1713            project.lsp_store().update(cx, |lsp_store, cx| {
1714                // Create some diagnostics
1715                lsp_store
1716                    .update_diagnostics(
1717                        LanguageServerId(0),
1718                        lsp::PublishDiagnosticsParams {
1719                            uri: path_to_buffer_uri.clone(),
1720                            diagnostics: vec![diagnostic],
1721                            version: None,
1722                        },
1723                        None,
1724                        language::DiagnosticSourceKind::Pushed,
1725                        &[],
1726                        cx,
1727                    )
1728                    .unwrap();
1729            });
1730        });
1731
1732        let buffer = project
1733            .update(cx, |project, cx| {
1734                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1735                project.open_buffer(path, cx)
1736            })
1737            .await
1738            .unwrap();
1739
1740        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1741        let position = snapshot.anchor_before(language::Point::new(0, 0));
1742
1743        let _prediction_task = zeta.update(cx, |zeta, cx| {
1744            zeta.request_prediction(&project, &buffer, position, cx)
1745        });
1746
1747        let (request, _respond_tx) = req_rx.next().await.unwrap();
1748
1749        assert_eq!(request.diagnostic_groups.len(), 1);
1750        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1751            .unwrap();
1752        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1753        assert_eq!(
1754            value,
1755            json!({
1756                "entries": [{
1757                    "range": {
1758                        "start": 8,
1759                        "end": 10
1760                    },
1761                    "diagnostic": {
1762                        "source": null,
1763                        "code": null,
1764                        "code_description": null,
1765                        "severity": 1,
1766                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1767                        "markdown": null,
1768                        "group_id": 0,
1769                        "is_primary": true,
1770                        "is_disk_based": false,
1771                        "is_unnecessary": false,
1772                        "source_kind": "Pushed",
1773                        "data": null,
1774                        "underline": true
1775                    }
1776                }],
1777                "primary_ix": 0
1778            })
1779        );
1780    }
1781
1782    fn init_test(
1783        cx: &mut TestAppContext,
1784    ) -> (
1785        Entity<Zeta>,
1786        mpsc::UnboundedReceiver<(
1787            predict_edits_v3::PredictEditsRequest,
1788            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1789        )>,
1790    ) {
1791        cx.update(move |cx| {
1792            let settings_store = SettingsStore::test(cx);
1793            cx.set_global(settings_store);
1794            language::init(cx);
1795            Project::init_settings(cx);
1796
1797            let (req_tx, req_rx) = mpsc::unbounded();
1798
1799            let http_client = FakeHttpClient::create({
1800                move |req| {
1801                    let uri = req.uri().path().to_string();
1802                    let mut body = req.into_body();
1803                    let req_tx = req_tx.clone();
1804                    async move {
1805                        let resp = match uri.as_str() {
1806                            "/client/llm_tokens" => serde_json::to_string(&json!({
1807                                "token": "test"
1808                            }))
1809                            .unwrap(),
1810                            "/predict_edits/v3" => {
1811                                let mut buf = Vec::new();
1812                                body.read_to_end(&mut buf).await.ok();
1813                                let req = serde_json::from_slice(&buf).unwrap();
1814
1815                                let (res_tx, res_rx) = oneshot::channel();
1816                                req_tx.unbounded_send((req, res_tx)).unwrap();
1817                                serde_json::to_string(&res_rx.await?).unwrap()
1818                            }
1819                            _ => {
1820                                panic!("Unexpected path: {}", uri)
1821                            }
1822                        };
1823
1824                        Ok(Response::builder().body(resp.into()).unwrap())
1825                    }
1826                }
1827            });
1828
1829            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1830            client.cloud_client().set_credentials(1, "test".into());
1831
1832            language_model::init(client.clone(), cx);
1833
1834            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1835            let zeta = Zeta::global(&client, &user_store, cx);
1836
1837            (zeta, req_rx)
1838        })
1839    }
1840}