zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
  10use collections::HashMap;
  11use edit_prediction_context::{
  12    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  13    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  14    SyntaxIndex, SyntaxIndexState,
  15};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::AsyncReadExt as _;
  18use futures::channel::{mpsc, oneshot};
  19use gpui::http_client::{AsyncBody, Method};
  20use gpui::{
  21    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  22    http_client, prelude::*,
  23};
  24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language::{BufferSnapshot, OffsetRangeExt};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use serde::de::DeserializeOwned;
  30use std::collections::{VecDeque, hash_map};
  31use std::fmt::Write;
  32use std::ops::Range;
  33use std::path::Path;
  34use std::str::FromStr as _;
  35use std::sync::Arc;
  36use std::time::{Duration, Instant};
  37use thiserror::Error;
  38use util::ResultExt as _;
  39use util::rel_path::RelPathBuf;
  40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  41
  42pub mod merge_excerpts;
  43mod prediction;
  44mod provider;
  45pub mod related_excerpts;
  46
  47use crate::merge_excerpts::merge_excerpts;
  48use crate::prediction::EditPrediction;
  49use crate::related_excerpts::find_related_excerpts;
  50pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
  51pub use provider::ZetaEditPredictionProvider;
  52
/// Consecutive edits to the same buffer that occur within this interval of the previous
/// one are coalesced into a single history event (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track per project. Once the history reaches this size,
/// the oldest half is discarded before a new event is appended.
const MAX_EVENT_COUNT: usize = 16;
  57
/// Default options for selecting the excerpt of the current buffer around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: [`ContextMode::Llm`] with [`DEFAULT_LLM_CONTEXT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);

/// Default options for [`ContextMode::Llm`].
pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for [`ContextMode::Syntax`]. Not used by [`DEFAULT_OPTIONS`], which
/// selects the `Llm` mode.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // NOTE(review): 0 presumably disables declaration retrieval — confirm intent.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Options a newly created [`Zeta`] starts with.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
  87
/// Feature flag named `zeta2` (presumably gating this edit prediction engine — confirm at
/// the call sites).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// Staff members are not opted in by default.
    fn enabled_for_staff() -> bool {
        false
    }
}
  97
/// Stores the app-wide [`Zeta`] entity as a GPUI global; installed by [`Zeta::global`]
/// and read by [`Zeta::try_global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 102
/// Edit prediction engine: tracks per-project edit history, context and the current
/// prediction, and builds prediction requests. Usually accessed through the app-wide
/// instance returned by [`Zeta::global`].
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token for LLM API requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event (see [`Zeta::new`]).
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Starts out `false`. NOTE(review): not read in the code visible here; presumably set
    /// when the server reports that a newer client version is required — confirm.
    update_required: bool,
    /// Destination for debug events, set once [`Zeta::debug_info`] has been called.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
 113
/// Tunables controlling how prediction requests are assembled.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for the request is gathered.
    pub context: ContextMode,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the nearby diagnostics included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested for the prediction.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
 122
 123#[derive(Debug, Clone, PartialEq)]
 124pub enum ContextMode {
 125    Llm(LlmContextOptions),
 126    Syntax(EditPredictionContextOptions),
 127}
 128
 129impl ContextMode {
 130    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 131        match self {
 132            ContextMode::Llm(options) => &options.excerpt,
 133            ContextMode::Syntax(options) => &options.excerpt,
 134        }
 135    }
 136}
 137
/// Debug events sent to the channel returned by [`Zeta::debug_info`].
///
/// NOTE(review): only `EditPredicted` is sent from the code visible in this chunk; the
/// context-retrieval variants are presumably emitted by the related-excerpts search.
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// A prediction request was built; its outcome arrives on the payload's `response_rx`.
    EditPredicted(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// Presumably the prompt used to generate search queries — TODO confirm.
    pub search_prompt: String,
}

/// Payload for the context-retrieval progress variants of [`ZetaDebugInfo`].
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredicted`].
pub struct ZetaEditPredictionDebugInfo {
    /// The request that was built.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics and context before the request was built.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested for.
    pub position: language::Anchor,
    /// The prompt as rendered locally by `build_prompt`, or its error message.
    pub local_prompt: Result<String, String>,
    /// Receives the prediction response (or an error message) once available.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub queries: Vec<SearchToolQuery>,
}

/// Alias for `predict_edits_v3::DebugInfo`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 174
/// Prediction state kept for a single project.
struct ZetaProject {
    /// Syntax index over the project, created when the project is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT` in `Zeta::push_event`.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Extra context ranges per buffer, sent as included files in `Llm`-mode requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    // NOTE(review): the three fields below are not read in the visible code; presumably
    // they track (debounced) refreshes of `context` — confirm.
    refresh_context_task: Option<Task<Option<()>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}
 185
 186#[derive(Debug, Clone)]
 187struct CurrentEditPrediction {
 188    pub requested_by_buffer_id: EntityId,
 189    pub prediction: EditPrediction,
 190}
 191
 192impl CurrentEditPrediction {
 193    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 194        let Some(new_edits) = self
 195            .prediction
 196            .interpolate(&self.prediction.buffer.read(cx))
 197        else {
 198            return false;
 199        };
 200
 201        if self.prediction.buffer != old_prediction.prediction.buffer {
 202            return true;
 203        }
 204
 205        let Some(old_edits) = old_prediction
 206            .prediction
 207            .interpolate(&old_prediction.prediction.buffer.read(cx))
 208        else {
 209            return true;
 210        };
 211
 212        // This reduces the occurrence of UI thrash from replacing edits
 213        //
 214        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 215        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 216            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 217            && old_edits.len() == 1
 218            && new_edits.len() == 1
 219        {
 220            let (old_range, old_text) = &old_edits[0];
 221            let (new_range, new_text) = &new_edits[0];
 222            new_range == old_range && new_text.starts_with(old_text)
 223        } else {
 224            true
 225        }
 226    }
 227}
 228
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets a different one.
    Jump { prediction: &'a EditPrediction },
}

/// A buffer whose edits are tracked for a project.
struct RegisteredBuffer {
    /// The latest snapshot already accounted for in the project's event history.
    snapshot: BufferSnapshot,
    /// Keeps the buffer's edit and release subscriptions alive.
    _subscriptions: [gpui::Subscription; 2],
}
 240
 241#[derive(Clone)]
 242pub enum Event {
 243    BufferChange {
 244        old_snapshot: BufferSnapshot,
 245        new_snapshot: BufferSnapshot,
 246        timestamp: Instant,
 247    },
 248}
 249
 250impl Event {
 251    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 252        match self {
 253            Event::BufferChange {
 254                old_snapshot,
 255                new_snapshot,
 256                ..
 257            } => {
 258                let path = new_snapshot.file().map(|f| f.full_path(cx));
 259
 260                let old_path = old_snapshot.file().and_then(|f| {
 261                    let old_path = f.full_path(cx);
 262                    if Some(&old_path) != path.as_ref() {
 263                        Some(old_path)
 264                    } else {
 265                        None
 266                    }
 267                });
 268
 269                // TODO [zeta2] move to bg?
 270                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 271
 272                if path == old_path && diff.is_empty() {
 273                    None
 274                } else {
 275                    Some(predict_edits_v3::Event::BufferChange {
 276                        old_path,
 277                        path,
 278                        diff,
 279                        //todo: Actually detect if this edit was predicted or not
 280                        predicted: false,
 281                    })
 282                }
 283            }
 284        }
 285    }
 286}
 287
 288impl Zeta {
 289    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 290        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 291    }
 292
 293    pub fn global(
 294        client: &Arc<Client>,
 295        user_store: &Entity<UserStore>,
 296        cx: &mut App,
 297    ) -> Entity<Self> {
 298        cx.try_global::<ZetaGlobal>()
 299            .map(|global| global.0.clone())
 300            .unwrap_or_else(|| {
 301                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 302                cx.set_global(ZetaGlobal(zeta.clone()));
 303                zeta
 304            })
 305    }
 306
 307    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 308        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 309
 310        Self {
 311            projects: HashMap::default(),
 312            client,
 313            user_store,
 314            options: DEFAULT_OPTIONS,
 315            llm_token: LlmApiToken::default(),
 316            _llm_token_subscription: cx.subscribe(
 317                &refresh_llm_token_listener,
 318                |this, _listener, _event, cx| {
 319                    let client = this.client.clone();
 320                    let llm_token = this.llm_token.clone();
 321                    cx.spawn(async move |_this, _cx| {
 322                        llm_token.refresh(&client).await?;
 323                        anyhow::Ok(())
 324                    })
 325                    .detach_and_log_err(cx);
 326                },
 327            ),
 328            update_required: false,
 329            debug_tx: None,
 330        }
 331    }
 332
 333    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 334        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 335        self.debug_tx = Some(debug_watch_tx);
 336        debug_watch_rx
 337    }
 338
    /// Returns the options used when building prediction requests.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 342
    /// Replaces the options used for subsequent prediction requests.
    ///
    /// Already-registered projects keep the `SyntaxIndex` they were created with, so a new
    /// `file_indexing_parallelism` only affects projects registered afterwards.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 346
 347    pub fn clear_history(&mut self) {
 348        for zeta_project in self.projects.values_mut() {
 349            zeta_project.events.clear();
 350        }
 351    }
 352
 353    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 354        self.projects
 355            .get(&project.entity_id())
 356            .map(|project| project.events.iter())
 357            .into_iter()
 358            .flatten()
 359    }
 360
 361    pub fn context_for_project(
 362        &self,
 363        project: &Entity<Project>,
 364    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 365        self.projects
 366            .get(&project.entity_id())
 367            .and_then(|project| {
 368                Some(
 369                    project
 370                        .context
 371                        .as_ref()?
 372                        .iter()
 373                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 374                )
 375            })
 376            .into_iter()
 377            .flatten()
 378    }
 379
    /// Returns the edit prediction usage known to the user store, if any.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
 383
    /// Starts tracking `project` (creating its syntax index) if it is not tracked yet.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
 387
    /// Starts tracking edits to `buffer` within `project`, registering the project first if
    /// necessary. Calling this for an already-registered buffer has no effect.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 397
 398    fn get_or_init_zeta_project(
 399        &mut self,
 400        project: &Entity<Project>,
 401        cx: &mut App,
 402    ) -> &mut ZetaProject {
 403        self.projects
 404            .entry(project.entity_id())
 405            .or_insert_with(|| ZetaProject {
 406                syntax_index: cx.new(|cx| {
 407                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 408                }),
 409                events: VecDeque::new(),
 410                registered_buffers: HashMap::default(),
 411                current_prediction: None,
 412                context: None,
 413                refresh_context_task: None,
 414                refresh_context_debounce_task: None,
 415                refresh_context_timestamp: None,
 416            })
 417    }
 418
 419    fn register_buffer_impl<'a>(
 420        zeta_project: &'a mut ZetaProject,
 421        buffer: &Entity<Buffer>,
 422        project: &Entity<Project>,
 423        cx: &mut Context<Self>,
 424    ) -> &'a mut RegisteredBuffer {
 425        let buffer_id = buffer.entity_id();
 426        match zeta_project.registered_buffers.entry(buffer_id) {
 427            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 428            hash_map::Entry::Vacant(entry) => {
 429                let snapshot = buffer.read(cx).snapshot();
 430                let project_entity_id = project.entity_id();
 431                entry.insert(RegisteredBuffer {
 432                    snapshot,
 433                    _subscriptions: [
 434                        cx.subscribe(buffer, {
 435                            let project = project.downgrade();
 436                            move |this, buffer, event, cx| {
 437                                if let language::BufferEvent::Edited = event
 438                                    && let Some(project) = project.upgrade()
 439                                {
 440                                    this.report_changes_for_buffer(&buffer, &project, cx);
 441                                }
 442                            }
 443                        }),
 444                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 445                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 446                            else {
 447                                return;
 448                            };
 449                            zeta_project.registered_buffers.remove(&buffer_id);
 450                        }),
 451                    ],
 452                })
 453            }
 454        }
 455    }
 456
 457    fn report_changes_for_buffer(
 458        &mut self,
 459        buffer: &Entity<Buffer>,
 460        project: &Entity<Project>,
 461        cx: &mut Context<Self>,
 462    ) -> BufferSnapshot {
 463        let zeta_project = self.get_or_init_zeta_project(project, cx);
 464        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 465
 466        let new_snapshot = buffer.read(cx).snapshot();
 467        if new_snapshot.version != registered_buffer.snapshot.version {
 468            let old_snapshot =
 469                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 470            Self::push_event(
 471                zeta_project,
 472                Event::BufferChange {
 473                    old_snapshot,
 474                    new_snapshot: new_snapshot.clone(),
 475                    timestamp: Instant::now(),
 476                },
 477            );
 478        }
 479
 480        new_snapshot
 481    }
 482
 483    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 484        let events = &mut zeta_project.events;
 485
 486        if let Some(Event::BufferChange {
 487            new_snapshot: last_new_snapshot,
 488            timestamp: last_timestamp,
 489            ..
 490        }) = events.back_mut()
 491        {
 492            // Coalesce edits for the same buffer when they happen one after the other.
 493            let Event::BufferChange {
 494                old_snapshot,
 495                new_snapshot,
 496                timestamp,
 497            } = &event;
 498
 499            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 500                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 501                && old_snapshot.version == last_new_snapshot.version
 502            {
 503                *last_new_snapshot = new_snapshot.clone();
 504                *last_timestamp = *timestamp;
 505                return;
 506            }
 507        }
 508
 509        if events.len() >= MAX_EVENT_COUNT {
 510            // These are halved instead of popping to improve prompt caching.
 511            events.drain(..MAX_EVENT_COUNT / 2);
 512        }
 513
 514        events.push_back(event);
 515    }
 516
 517    fn current_prediction_for_buffer(
 518        &self,
 519        buffer: &Entity<Buffer>,
 520        project: &Entity<Project>,
 521        cx: &App,
 522    ) -> Option<BufferEditPrediction<'_>> {
 523        let project_state = self.projects.get(&project.entity_id())?;
 524
 525        let CurrentEditPrediction {
 526            requested_by_buffer_id,
 527            prediction,
 528        } = project_state.current_prediction.as_ref()?;
 529
 530        if prediction.targets_buffer(buffer.read(cx), cx) {
 531            Some(BufferEditPrediction::Local { prediction })
 532        } else if *requested_by_buffer_id == buffer.entity_id() {
 533            Some(BufferEditPrediction::Jump { prediction })
 534        } else {
 535            None
 536        }
 537    }
 538
 539    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 540        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 541            return;
 542        };
 543
 544        let Some(prediction) = project_state.current_prediction.take() else {
 545            return;
 546        };
 547        let request_id = prediction.prediction.id.into();
 548
 549        let client = self.client.clone();
 550        let llm_token = self.llm_token.clone();
 551        let app_version = AppVersion::global(cx);
 552        cx.spawn(async move |this, cx| {
 553            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 554                http_client::Url::parse(&predict_edits_url)?
 555            } else {
 556                client
 557                    .http_client()
 558                    .build_zed_llm_url("/predict_edits/accept", &[])?
 559            };
 560
 561            let response = cx
 562                .background_spawn(Self::send_api_request::<()>(
 563                    move |builder| {
 564                        let req = builder.uri(url.as_ref()).body(
 565                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 566                        );
 567                        Ok(req?)
 568                    },
 569                    client,
 570                    llm_token,
 571                    app_version,
 572                ))
 573                .await;
 574
 575            Self::handle_api_response(&this, response, cx)?;
 576            anyhow::Ok(())
 577        })
 578        .detach_and_log_err(cx);
 579    }
 580
 581    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 582        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 583            project_state.current_prediction.take();
 584        };
 585    }
 586
 587    pub fn refresh_prediction(
 588        &mut self,
 589        project: &Entity<Project>,
 590        buffer: &Entity<Buffer>,
 591        position: language::Anchor,
 592        cx: &mut Context<Self>,
 593    ) -> Task<Result<()>> {
 594        let request_task = self.request_prediction(project, buffer, position, cx);
 595        let buffer = buffer.clone();
 596        let project = project.clone();
 597
 598        cx.spawn(async move |this, cx| {
 599            if let Some(prediction) = request_task.await? {
 600                this.update(cx, |this, cx| {
 601                    let project_state = this
 602                        .projects
 603                        .get_mut(&project.entity_id())
 604                        .context("Project not found")?;
 605
 606                    let new_prediction = CurrentEditPrediction {
 607                        requested_by_buffer_id: buffer.entity_id(),
 608                        prediction: prediction,
 609                    };
 610
 611                    if project_state
 612                        .current_prediction
 613                        .as_ref()
 614                        .is_none_or(|old_prediction| {
 615                            new_prediction.should_replace_prediction(&old_prediction, cx)
 616                        })
 617                    {
 618                        project_state.current_prediction = Some(new_prediction);
 619                    }
 620                    anyhow::Ok(())
 621                })??;
 622            }
 623            Ok(())
 624        })
 625    }
 626
 627    pub fn request_prediction(
 628        &mut self,
 629        project: &Entity<Project>,
 630        buffer: &Entity<Buffer>,
 631        position: language::Anchor,
 632        cx: &mut Context<Self>,
 633    ) -> Task<Result<Option<EditPrediction>>> {
 634        let project_state = self.projects.get(&project.entity_id());
 635
 636        let index_state = project_state.map(|state| {
 637            state
 638                .syntax_index
 639                .read_with(cx, |index, _cx| index.state().clone())
 640        });
 641        let options = self.options.clone();
 642        let snapshot = buffer.read(cx).snapshot();
 643        let Some(excerpt_path) = snapshot
 644            .file()
 645            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 646        else {
 647            return Task::ready(Err(anyhow!("No file path for excerpt")));
 648        };
 649        let client = self.client.clone();
 650        let llm_token = self.llm_token.clone();
 651        let app_version = AppVersion::global(cx);
 652        let worktree_snapshots = project
 653            .read(cx)
 654            .worktrees(cx)
 655            .map(|worktree| worktree.read(cx).snapshot())
 656            .collect::<Vec<_>>();
 657        let debug_tx = self.debug_tx.clone();
 658
 659        let events = project_state
 660            .map(|state| {
 661                state
 662                    .events
 663                    .iter()
 664                    .filter_map(|event| event.to_request_event(cx))
 665                    .collect::<Vec<_>>()
 666            })
 667            .unwrap_or_default();
 668
 669        let diagnostics = snapshot.diagnostic_sets().clone();
 670
 671        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 672            let mut path = f.worktree.read(cx).absolutize(&f.path);
 673            if path.pop() { Some(path) } else { None }
 674        });
 675
 676        // TODO data collection
 677        let can_collect_data = cx.is_staff();
 678
 679        let mut included_files = project_state
 680            .and_then(|project_state| project_state.context.as_ref())
 681            .unwrap_or(&HashMap::default())
 682            .iter()
 683            .filter_map(|(buffer, ranges)| {
 684                let buffer = buffer.read(cx);
 685                Some((
 686                    buffer.snapshot(),
 687                    buffer.file()?.full_path(cx).into(),
 688                    ranges.clone(),
 689                ))
 690            })
 691            .collect::<Vec<_>>();
 692
 693        let request_task = cx.background_spawn({
 694            let snapshot = snapshot.clone();
 695            let buffer = buffer.clone();
 696            async move {
 697                let index_state = if let Some(index_state) = index_state {
 698                    Some(index_state.lock_owned().await)
 699                } else {
 700                    None
 701                };
 702
 703                let cursor_offset = position.to_offset(&snapshot);
 704                let cursor_point = cursor_offset.to_point(&snapshot);
 705
 706                let before_retrieval = chrono::Utc::now();
 707
 708                let (diagnostic_groups, diagnostic_groups_truncated) =
 709                    Self::gather_nearby_diagnostics(
 710                        cursor_offset,
 711                        &diagnostics,
 712                        &snapshot,
 713                        options.max_diagnostic_bytes,
 714                    );
 715
 716                let request = match options.context {
 717                    ContextMode::Llm(context_options) => {
 718                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 719                            cursor_point,
 720                            &snapshot,
 721                            &context_options.excerpt,
 722                            index_state.as_deref(),
 723                        ) else {
 724                            return Ok((None, None));
 725                        };
 726
 727                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
 728                            ..snapshot.anchor_before(excerpt.range.end);
 729
 730                        if let Some(buffer_ix) = included_files
 731                            .iter()
 732                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
 733                        {
 734                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
 735                            let range_ix = ranges
 736                                .binary_search_by(|probe| {
 737                                    probe
 738                                        .start
 739                                        .cmp(&excerpt_anchor_range.start, buffer)
 740                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 741                                })
 742                                .unwrap_or_else(|ix| ix);
 743
 744                            ranges.insert(range_ix, excerpt_anchor_range);
 745                            let last_ix = included_files.len() - 1;
 746                            included_files.swap(buffer_ix, last_ix);
 747                        } else {
 748                            included_files.push((
 749                                snapshot,
 750                                excerpt_path.clone(),
 751                                vec![excerpt_anchor_range],
 752                            ));
 753                        }
 754
 755                        let included_files = included_files
 756                            .into_iter()
 757                            .map(|(buffer, path, ranges)| {
 758                                let excerpts = merge_excerpts(
 759                                    &buffer,
 760                                    ranges.iter().map(|range| {
 761                                        let point_range = range.to_point(&buffer);
 762                                        Line(point_range.start.row)..Line(point_range.end.row)
 763                                    }),
 764                                );
 765                                predict_edits_v3::IncludedFile {
 766                                    path,
 767                                    max_row: Line(buffer.max_point().row),
 768                                    excerpts,
 769                                }
 770                            })
 771                            .collect::<Vec<_>>();
 772
 773                        predict_edits_v3::PredictEditsRequest {
 774                            excerpt_path,
 775                            excerpt: String::new(),
 776                            excerpt_line_range: Line(0)..Line(0),
 777                            excerpt_range: 0..0,
 778                            cursor_point: predict_edits_v3::Point {
 779                                line: predict_edits_v3::Line(cursor_point.row),
 780                                column: cursor_point.column,
 781                            },
 782                            included_files,
 783                            referenced_declarations: vec![],
 784                            events,
 785                            can_collect_data,
 786                            diagnostic_groups,
 787                            diagnostic_groups_truncated,
 788                            debug_info: debug_tx.is_some(),
 789                            prompt_max_bytes: Some(options.max_prompt_bytes),
 790                            prompt_format: options.prompt_format,
 791                            // TODO [zeta2]
 792                            signatures: vec![],
 793                            excerpt_parent: None,
 794                            git_info: None,
 795                        }
 796                    }
 797                    ContextMode::Syntax(context_options) => {
 798                        let Some(context) = EditPredictionContext::gather_context(
 799                            cursor_point,
 800                            &snapshot,
 801                            parent_abs_path.as_deref(),
 802                            &context_options,
 803                            index_state.as_deref(),
 804                        ) else {
 805                            return Ok((None, None));
 806                        };
 807
 808                        make_syntax_context_cloud_request(
 809                            excerpt_path,
 810                            context,
 811                            events,
 812                            can_collect_data,
 813                            diagnostic_groups,
 814                            diagnostic_groups_truncated,
 815                            None,
 816                            debug_tx.is_some(),
 817                            &worktree_snapshots,
 818                            index_state.as_deref(),
 819                            Some(options.max_prompt_bytes),
 820                            options.prompt_format,
 821                        )
 822                    }
 823                };
 824
 825                let retrieval_time = chrono::Utc::now() - before_retrieval;
 826
 827                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 828                    let (response_tx, response_rx) = oneshot::channel();
 829
 830                    let local_prompt = build_prompt(&request)
 831                        .map(|(prompt, _)| prompt)
 832                        .map_err(|err| err.to_string());
 833
 834                    debug_tx
 835                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
 836                            request: request.clone(),
 837                            retrieval_time,
 838                            buffer: buffer.downgrade(),
 839                            local_prompt,
 840                            position,
 841                            response_rx,
 842                        }))
 843                        .ok();
 844                    Some(response_tx)
 845                } else {
 846                    None
 847                };
 848
 849                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 850                    if let Some(debug_response_tx) = debug_response_tx {
 851                        debug_response_tx
 852                            .send(Err("Request skipped".to_string()))
 853                            .ok();
 854                    }
 855                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 856                }
 857
 858                let response =
 859                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 860
 861                if let Some(debug_response_tx) = debug_response_tx {
 862                    debug_response_tx
 863                        .send(
 864                            response
 865                                .as_ref()
 866                                .map_err(|err| err.to_string())
 867                                .map(|response| response.0.clone()),
 868                        )
 869                        .ok();
 870                }
 871
 872                response.map(|(res, usage)| (Some(res), usage))
 873            }
 874        });
 875
 876        let buffer = buffer.clone();
 877
 878        cx.spawn({
 879            let project = project.clone();
 880            async move |this, cx| {
 881                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 882                else {
 883                    return Ok(None);
 884                };
 885
 886                // TODO telemetry: duration, etc
 887                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 888            }
 889        })
 890    }
 891
 892    async fn send_prediction_request(
 893        client: Arc<Client>,
 894        llm_token: LlmApiToken,
 895        app_version: SemanticVersion,
 896        request: predict_edits_v3::PredictEditsRequest,
 897    ) -> Result<(
 898        predict_edits_v3::PredictEditsResponse,
 899        Option<EditPredictionUsage>,
 900    )> {
 901        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 902            http_client::Url::parse(&predict_edits_url)?
 903        } else {
 904            client
 905                .http_client()
 906                .build_zed_llm_url("/predict_edits/v3", &[])?
 907        };
 908
 909        Self::send_api_request(
 910            |builder| {
 911                let req = builder
 912                    .uri(url.as_ref())
 913                    .body(serde_json::to_string(&request)?.into());
 914                Ok(req?)
 915            },
 916            client,
 917            llm_token,
 918            app_version,
 919        )
 920        .await
 921    }
 922
 923    fn handle_api_response<T>(
 924        this: &WeakEntity<Self>,
 925        response: Result<(T, Option<EditPredictionUsage>)>,
 926        cx: &mut gpui::AsyncApp,
 927    ) -> Result<T> {
 928        match response {
 929            Ok((data, usage)) => {
 930                if let Some(usage) = usage {
 931                    this.update(cx, |this, cx| {
 932                        this.user_store.update(cx, |user_store, cx| {
 933                            user_store.update_edit_prediction_usage(usage, cx);
 934                        });
 935                    })
 936                    .ok();
 937                }
 938                Ok(data)
 939            }
 940            Err(err) => {
 941                if err.is::<ZedUpdateRequiredError>() {
 942                    cx.update(|cx| {
 943                        this.update(cx, |this, _cx| {
 944                            this.update_required = true;
 945                        })
 946                        .ok();
 947
 948                        let error_message: SharedString = err.to_string().into();
 949                        show_app_notification(
 950                            NotificationId::unique::<ZedUpdateRequiredError>(),
 951                            cx,
 952                            move |cx| {
 953                                cx.new(|cx| {
 954                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 955                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 956                                })
 957                            },
 958                        );
 959                    })
 960                    .ok();
 961                }
 962                Err(err)
 963            }
 964        }
 965    }
 966
 967    async fn send_api_request<Res>(
 968        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 969        client: Arc<Client>,
 970        llm_token: LlmApiToken,
 971        app_version: SemanticVersion,
 972    ) -> Result<(Res, Option<EditPredictionUsage>)>
 973    where
 974        Res: DeserializeOwned,
 975    {
 976        let http_client = client.http_client();
 977        let mut token = llm_token.acquire(&client).await?;
 978        let mut did_retry = false;
 979
 980        loop {
 981            let request_builder = http_client::Request::builder().method(Method::POST);
 982
 983            let request = build(
 984                request_builder
 985                    .header("Content-Type", "application/json")
 986                    .header("Authorization", format!("Bearer {}", token))
 987                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 988            )?;
 989
 990            let mut response = http_client.send(request).await?;
 991
 992            if let Some(minimum_required_version) = response
 993                .headers()
 994                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 995                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 996            {
 997                anyhow::ensure!(
 998                    app_version >= minimum_required_version,
 999                    ZedUpdateRequiredError {
1000                        minimum_version: minimum_required_version
1001                    }
1002                );
1003            }
1004
1005            if response.status().is_success() {
1006                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1007
1008                let mut body = Vec::new();
1009                response.body_mut().read_to_end(&mut body).await?;
1010                return Ok((serde_json::from_slice(&body)?, usage));
1011            } else if !did_retry
1012                && response
1013                    .headers()
1014                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1015                    .is_some()
1016            {
1017                did_retry = true;
1018                token = llm_token.refresh(&client).await?;
1019            } else {
1020                let mut body = String::new();
1021                response.body_mut().read_to_string(&mut body).await?;
1022                anyhow::bail!(
1023                    "Request failed with status: {:?}\nBody: {}",
1024                    response.status(),
1025                    body
1026                );
1027            }
1028        }
1029    }
1030
1031    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1032    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1033
1034    // Refresh the related excerpts when the user just beguns editing after
1035    // an idle period, and after they pause editing.
1036    fn refresh_context_if_needed(
1037        &mut self,
1038        project: &Entity<Project>,
1039        buffer: &Entity<language::Buffer>,
1040        cursor_position: language::Anchor,
1041        cx: &mut Context<Self>,
1042    ) {
1043        if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1044            return;
1045        }
1046
1047        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1048            return;
1049        };
1050
1051        let now = Instant::now();
1052        let was_idle = zeta_project
1053            .refresh_context_timestamp
1054            .map_or(true, |timestamp| {
1055                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1056            });
1057        zeta_project.refresh_context_timestamp = Some(now);
1058        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1059            let buffer = buffer.clone();
1060            let project = project.clone();
1061            async move |this, cx| {
1062                if was_idle {
1063                    log::debug!("refetching edit prediction context after idle");
1064                } else {
1065                    cx.background_executor()
1066                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1067                        .await;
1068                    log::debug!("refetching edit prediction context after pause");
1069                }
1070                this.update(cx, |this, cx| {
1071                    this.refresh_context(project, buffer, cursor_position, cx);
1072                })
1073                .ok()
1074            }
1075        }));
1076    }
1077
1078    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1079    // and avoid spawning more than one concurrent task.
1080    fn refresh_context(
1081        &mut self,
1082        project: Entity<Project>,
1083        buffer: Entity<language::Buffer>,
1084        cursor_position: language::Anchor,
1085        cx: &mut Context<Self>,
1086    ) {
1087        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1088            return;
1089        };
1090
1091        let debug_tx = self.debug_tx.clone();
1092
1093        zeta_project
1094            .refresh_context_task
1095            .get_or_insert(cx.spawn(async move |this, cx| {
1096                let related_excerpts = this
1097                    .update(cx, |this, cx| {
1098                        let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1099                            return Task::ready(anyhow::Ok(HashMap::default()));
1100                        };
1101
1102                        let ContextMode::Llm(options) = &this.options().context else {
1103                            return Task::ready(anyhow::Ok(HashMap::default()));
1104                        };
1105
1106                        let mut edit_history_unified_diff = String::new();
1107
1108                        for event in zeta_project.events.iter() {
1109                            if let Some(event) = event.to_request_event(cx) {
1110                                writeln!(&mut edit_history_unified_diff, "{event}").ok();
1111                            }
1112                        }
1113
1114                        find_related_excerpts(
1115                            buffer.clone(),
1116                            cursor_position,
1117                            &project,
1118                            edit_history_unified_diff,
1119                            options,
1120                            debug_tx,
1121                            cx,
1122                        )
1123                    })
1124                    .ok()?
1125                    .await
1126                    .log_err()
1127                    .unwrap_or_default();
1128                this.update(cx, |this, _cx| {
1129                    let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1130                        return;
1131                    };
1132                    zeta_project.context = Some(related_excerpts);
1133                    zeta_project.refresh_context_task.take();
1134                    if let Some(debug_tx) = &this.debug_tx {
1135                        debug_tx
1136                            .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1137                                ZetaContextRetrievalDebugInfo {
1138                                    project,
1139                                    timestamp: Instant::now(),
1140                                },
1141                            ))
1142                            .ok();
1143                    }
1144                })
1145                .ok()
1146            }));
1147    }
1148
1149    fn gather_nearby_diagnostics(
1150        cursor_offset: usize,
1151        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1152        snapshot: &BufferSnapshot,
1153        max_diagnostics_bytes: usize,
1154    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1155        // TODO: Could make this more efficient
1156        let mut diagnostic_groups = Vec::new();
1157        for (language_server_id, diagnostics) in diagnostic_sets {
1158            let mut groups = Vec::new();
1159            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1160            diagnostic_groups.extend(
1161                groups
1162                    .into_iter()
1163                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1164            );
1165        }
1166
1167        // sort by proximity to cursor
1168        diagnostic_groups.sort_by_key(|group| {
1169            let range = &group.entries[group.primary_ix].range;
1170            if range.start >= cursor_offset {
1171                range.start - cursor_offset
1172            } else if cursor_offset >= range.end {
1173                cursor_offset - range.end
1174            } else {
1175                (cursor_offset - range.start).min(range.end - cursor_offset)
1176            }
1177        });
1178
1179        let mut results = Vec::new();
1180        let mut diagnostic_groups_truncated = false;
1181        let mut diagnostics_byte_count = 0;
1182        for group in diagnostic_groups {
1183            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1184            diagnostics_byte_count += raw_value.get().len();
1185            if diagnostics_byte_count > max_diagnostics_bytes {
1186                diagnostic_groups_truncated = true;
1187                break;
1188            }
1189            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1190        }
1191
1192        (results, diagnostic_groups_truncated)
1193    }
1194
1195    // TODO: Dedupe with similar code in request_prediction?
1196    pub fn cloud_request_for_zeta_cli(
1197        &mut self,
1198        project: &Entity<Project>,
1199        buffer: &Entity<Buffer>,
1200        position: language::Anchor,
1201        cx: &mut Context<Self>,
1202    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1203        let project_state = self.projects.get(&project.entity_id());
1204
1205        let index_state = project_state.map(|state| {
1206            state
1207                .syntax_index
1208                .read_with(cx, |index, _cx| index.state().clone())
1209        });
1210        let options = self.options.clone();
1211        let snapshot = buffer.read(cx).snapshot();
1212        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1213            return Task::ready(Err(anyhow!("No file path for excerpt")));
1214        };
1215        let worktree_snapshots = project
1216            .read(cx)
1217            .worktrees(cx)
1218            .map(|worktree| worktree.read(cx).snapshot())
1219            .collect::<Vec<_>>();
1220
1221        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1222            let mut path = f.worktree.read(cx).absolutize(&f.path);
1223            if path.pop() { Some(path) } else { None }
1224        });
1225
1226        cx.background_spawn(async move {
1227            let index_state = if let Some(index_state) = index_state {
1228                Some(index_state.lock_owned().await)
1229            } else {
1230                None
1231            };
1232
1233            let cursor_point = position.to_point(&snapshot);
1234
1235            let debug_info = true;
1236            EditPredictionContext::gather_context(
1237                cursor_point,
1238                &snapshot,
1239                parent_abs_path.as_deref(),
1240                match &options.context {
1241                    ContextMode::Llm(_) => {
1242                        // TODO
1243                        panic!("Llm mode not supported in zeta cli yet");
1244                    }
1245                    ContextMode::Syntax(edit_prediction_context_options) => {
1246                        edit_prediction_context_options
1247                    }
1248                },
1249                index_state.as_deref(),
1250            )
1251            .context("Failed to select excerpt")
1252            .map(|context| {
1253                make_syntax_context_cloud_request(
1254                    excerpt_path.into(),
1255                    context,
1256                    // TODO pass everything
1257                    Vec::new(),
1258                    false,
1259                    Vec::new(),
1260                    false,
1261                    None,
1262                    debug_info,
1263                    &worktree_snapshots,
1264                    index_state.as_deref(),
1265                    Some(options.max_prompt_bytes),
1266                    options.prompt_format,
1267                )
1268            })
1269        })
1270    }
1271
1272    pub fn wait_for_initial_indexing(
1273        &mut self,
1274        project: &Entity<Project>,
1275        cx: &mut App,
1276    ) -> Task<Result<()>> {
1277        let zeta_project = self.get_or_init_zeta_project(project, cx);
1278        zeta_project
1279            .syntax_index
1280            .read(cx)
1281            .wait_for_initial_file_indexing(cx)
1282    }
1283}
1284
1285#[derive(Error, Debug)]
1286#[error(
1287    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1288)]
1289pub struct ZedUpdateRequiredError {
1290    minimum_version: SemanticVersion,
1291}
1292
1293fn make_syntax_context_cloud_request(
1294    excerpt_path: Arc<Path>,
1295    context: EditPredictionContext,
1296    events: Vec<predict_edits_v3::Event>,
1297    can_collect_data: bool,
1298    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1299    diagnostic_groups_truncated: bool,
1300    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1301    debug_info: bool,
1302    worktrees: &Vec<worktree::Snapshot>,
1303    index_state: Option<&SyntaxIndexState>,
1304    prompt_max_bytes: Option<usize>,
1305    prompt_format: PromptFormat,
1306) -> predict_edits_v3::PredictEditsRequest {
1307    let mut signatures = Vec::new();
1308    let mut declaration_to_signature_index = HashMap::default();
1309    let mut referenced_declarations = Vec::new();
1310
1311    for snippet in context.declarations {
1312        let project_entry_id = snippet.declaration.project_entry_id();
1313        let Some(path) = worktrees.iter().find_map(|worktree| {
1314            worktree.entry_for_id(project_entry_id).map(|entry| {
1315                let mut full_path = RelPathBuf::new();
1316                full_path.push(worktree.root_name());
1317                full_path.push(&entry.path);
1318                full_path
1319            })
1320        }) else {
1321            continue;
1322        };
1323
1324        let parent_index = index_state.and_then(|index_state| {
1325            snippet.declaration.parent().and_then(|parent| {
1326                add_signature(
1327                    parent,
1328                    &mut declaration_to_signature_index,
1329                    &mut signatures,
1330                    index_state,
1331                )
1332            })
1333        });
1334
1335        let (text, text_is_truncated) = snippet.declaration.item_text();
1336        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1337            path: path.as_std_path().into(),
1338            text: text.into(),
1339            range: snippet.declaration.item_line_range(),
1340            text_is_truncated,
1341            signature_range: snippet.declaration.signature_range_in_item_text(),
1342            parent_index,
1343            signature_score: snippet.score(DeclarationStyle::Signature),
1344            declaration_score: snippet.score(DeclarationStyle::Declaration),
1345            score_components: snippet.components,
1346        });
1347    }
1348
1349    let excerpt_parent = index_state.and_then(|index_state| {
1350        context
1351            .excerpt
1352            .parent_declarations
1353            .last()
1354            .and_then(|(parent, _)| {
1355                add_signature(
1356                    *parent,
1357                    &mut declaration_to_signature_index,
1358                    &mut signatures,
1359                    index_state,
1360                )
1361            })
1362    });
1363
1364    predict_edits_v3::PredictEditsRequest {
1365        excerpt_path,
1366        excerpt: context.excerpt_text.body,
1367        excerpt_line_range: context.excerpt.line_range,
1368        excerpt_range: context.excerpt.range,
1369        cursor_point: predict_edits_v3::Point {
1370            line: predict_edits_v3::Line(context.cursor_point.row),
1371            column: context.cursor_point.column,
1372        },
1373        referenced_declarations,
1374        included_files: vec![],
1375        signatures,
1376        excerpt_parent,
1377        events,
1378        can_collect_data,
1379        diagnostic_groups,
1380        diagnostic_groups_truncated,
1381        git_info,
1382        debug_info,
1383        prompt_max_bytes,
1384        prompt_format,
1385    }
1386}
1387
1388fn add_signature(
1389    declaration_id: DeclarationId,
1390    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1391    signatures: &mut Vec<Signature>,
1392    index: &SyntaxIndexState,
1393) -> Option<usize> {
1394    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1395        return Some(*signature_index);
1396    }
1397    let Some(parent_declaration) = index.declaration(declaration_id) else {
1398        log::error!("bug: missing parent declaration");
1399        return None;
1400    };
1401    let parent_index = parent_declaration.parent().and_then(|parent| {
1402        add_signature(parent, declaration_to_signature_index, signatures, index)
1403    });
1404    let (text, text_is_truncated) = parent_declaration.signature_text();
1405    let signature_index = signatures.len();
1406    signatures.push(Signature {
1407        text: text.into(),
1408        text_is_truncated,
1409        parent_index,
1410        range: parent_declaration.signature_line_range(),
1411    });
1412    declaration_to_signature_index.insert(declaration_id, signature_index);
1413    Some(signature_index)
1414}
1415
1416#[cfg(test)]
1417mod tests {
1418    use std::{
1419        path::{Path, PathBuf},
1420        sync::Arc,
1421    };
1422
1423    use client::UserStore;
1424    use clock::FakeSystemClock;
1425    use cloud_llm_client::predict_edits_v3::{self, Point};
1426    use edit_prediction_context::Line;
1427    use futures::{
1428        AsyncReadExt, StreamExt,
1429        channel::{mpsc, oneshot},
1430    };
1431    use gpui::{
1432        Entity, TestAppContext,
1433        http_client::{FakeHttpClient, Response},
1434        prelude::*,
1435    };
1436    use indoc::indoc;
1437    use language::{LanguageServerId, OffsetRangeExt as _};
1438    use pretty_assertions::{assert_eq, assert_matches};
1439    use project::{FakeFs, Project};
1440    use serde_json::json;
1441    use settings::SettingsStore;
1442    use util::path;
1443    use uuid::Uuid;
1444
1445    use crate::{BufferEditPrediction, Zeta};
1446
1447    #[gpui::test]
1448    async fn test_current_state(cx: &mut TestAppContext) {
1449        let (zeta, mut req_rx) = init_test(cx);
1450        let fs = FakeFs::new(cx.executor());
1451        fs.insert_tree(
1452            "/root",
1453            json!({
1454                "1.txt": "Hello!\nHow\nBye",
1455                "2.txt": "Hola!\nComo\nAdios"
1456            }),
1457        )
1458        .await;
1459        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1460
1461        zeta.update(cx, |zeta, cx| {
1462            zeta.register_project(&project, cx);
1463        });
1464
1465        let buffer1 = project
1466            .update(cx, |project, cx| {
1467                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1468                project.open_buffer(path, cx)
1469            })
1470            .await
1471            .unwrap();
1472        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1473        let position = snapshot1.anchor_before(language::Point::new(1, 3));
1474
1475        // Prediction for current file
1476
1477        let prediction_task = zeta.update(cx, |zeta, cx| {
1478            zeta.refresh_prediction(&project, &buffer1, position, cx)
1479        });
1480        let (_request, respond_tx) = req_rx.next().await.unwrap();
1481        respond_tx
1482            .send(predict_edits_v3::PredictEditsResponse {
1483                request_id: Uuid::new_v4(),
1484                edits: vec![predict_edits_v3::Edit {
1485                    path: Path::new(path!("root/1.txt")).into(),
1486                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1487                    content: "Hello!\nHow are you?\nBye".into(),
1488                }],
1489                debug_info: None,
1490            })
1491            .unwrap();
1492        prediction_task.await.unwrap();
1493
1494        zeta.read_with(cx, |zeta, cx| {
1495            let prediction = zeta
1496                .current_prediction_for_buffer(&buffer1, &project, cx)
1497                .unwrap();
1498            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1499        });
1500
1501        // Prediction for another file
1502        let prediction_task = zeta.update(cx, |zeta, cx| {
1503            zeta.refresh_prediction(&project, &buffer1, position, cx)
1504        });
1505        let (_request, respond_tx) = req_rx.next().await.unwrap();
1506        respond_tx
1507            .send(predict_edits_v3::PredictEditsResponse {
1508                request_id: Uuid::new_v4(),
1509                edits: vec![predict_edits_v3::Edit {
1510                    path: Path::new(path!("root/2.txt")).into(),
1511                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1512                    content: "Hola!\nComo estas?\nAdios".into(),
1513                }],
1514                debug_info: None,
1515            })
1516            .unwrap();
1517        prediction_task.await.unwrap();
1518        zeta.read_with(cx, |zeta, cx| {
1519            let prediction = zeta
1520                .current_prediction_for_buffer(&buffer1, &project, cx)
1521                .unwrap();
1522            assert_matches!(
1523                prediction,
1524                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1525            );
1526        });
1527
1528        let buffer2 = project
1529            .update(cx, |project, cx| {
1530                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1531                project.open_buffer(path, cx)
1532            })
1533            .await
1534            .unwrap();
1535
1536        zeta.read_with(cx, |zeta, cx| {
1537            let prediction = zeta
1538                .current_prediction_for_buffer(&buffer2, &project, cx)
1539                .unwrap();
1540            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1541        });
1542    }
1543
1544    #[gpui::test]
1545    async fn test_simple_request(cx: &mut TestAppContext) {
1546        let (zeta, mut req_rx) = init_test(cx);
1547        let fs = FakeFs::new(cx.executor());
1548        fs.insert_tree(
1549            "/root",
1550            json!({
1551                "foo.md":  "Hello!\nHow\nBye"
1552            }),
1553        )
1554        .await;
1555        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1556
1557        let buffer = project
1558            .update(cx, |project, cx| {
1559                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1560                project.open_buffer(path, cx)
1561            })
1562            .await
1563            .unwrap();
1564        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1565        let position = snapshot.anchor_before(language::Point::new(1, 3));
1566
1567        let prediction_task = zeta.update(cx, |zeta, cx| {
1568            zeta.request_prediction(&project, &buffer, position, cx)
1569        });
1570
1571        let (request, respond_tx) = req_rx.next().await.unwrap();
1572        assert_eq!(
1573            request.excerpt_path.as_ref(),
1574            Path::new(path!("root/foo.md"))
1575        );
1576        assert_eq!(
1577            request.cursor_point,
1578            Point {
1579                line: Line(1),
1580                column: 3
1581            }
1582        );
1583
1584        respond_tx
1585            .send(predict_edits_v3::PredictEditsResponse {
1586                request_id: Uuid::new_v4(),
1587                edits: vec![predict_edits_v3::Edit {
1588                    path: Path::new(path!("root/foo.md")).into(),
1589                    range: Line(0)..Line(snapshot.max_point().row + 1),
1590                    content: "Hello!\nHow are you?\nBye".into(),
1591                }],
1592                debug_info: None,
1593            })
1594            .unwrap();
1595
1596        let prediction = prediction_task.await.unwrap().unwrap();
1597
1598        assert_eq!(prediction.edits.len(), 1);
1599        assert_eq!(
1600            prediction.edits[0].0.to_point(&snapshot).start,
1601            language::Point::new(1, 3)
1602        );
1603        assert_eq!(prediction.edits[0].1, " are you?");
1604    }
1605
1606    #[gpui::test]
1607    async fn test_request_events(cx: &mut TestAppContext) {
1608        let (zeta, mut req_rx) = init_test(cx);
1609        let fs = FakeFs::new(cx.executor());
1610        fs.insert_tree(
1611            "/root",
1612            json!({
1613                "foo.md": "Hello!\n\nBye"
1614            }),
1615        )
1616        .await;
1617        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1618
1619        let buffer = project
1620            .update(cx, |project, cx| {
1621                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1622                project.open_buffer(path, cx)
1623            })
1624            .await
1625            .unwrap();
1626
1627        zeta.update(cx, |zeta, cx| {
1628            zeta.register_buffer(&buffer, &project, cx);
1629        });
1630
1631        buffer.update(cx, |buffer, cx| {
1632            buffer.edit(vec![(7..7, "How")], None, cx);
1633        });
1634
1635        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1636        let position = snapshot.anchor_before(language::Point::new(1, 3));
1637
1638        let prediction_task = zeta.update(cx, |zeta, cx| {
1639            zeta.request_prediction(&project, &buffer, position, cx)
1640        });
1641
1642        let (request, respond_tx) = req_rx.next().await.unwrap();
1643
1644        assert_eq!(request.events.len(), 1);
1645        assert_eq!(
1646            request.events[0],
1647            predict_edits_v3::Event::BufferChange {
1648                path: Some(PathBuf::from(path!("root/foo.md"))),
1649                old_path: None,
1650                diff: indoc! {"
1651                        @@ -1,3 +1,3 @@
1652                         Hello!
1653                        -
1654                        +How
1655                         Bye
1656                    "}
1657                .to_string(),
1658                predicted: false
1659            }
1660        );
1661
1662        respond_tx
1663            .send(predict_edits_v3::PredictEditsResponse {
1664                request_id: Uuid::new_v4(),
1665                edits: vec![predict_edits_v3::Edit {
1666                    path: Path::new(path!("root/foo.md")).into(),
1667                    range: Line(0)..Line(snapshot.max_point().row + 1),
1668                    content: "Hello!\nHow are you?\nBye".into(),
1669                }],
1670                debug_info: None,
1671            })
1672            .unwrap();
1673
1674        let prediction = prediction_task.await.unwrap().unwrap();
1675
1676        assert_eq!(prediction.edits.len(), 1);
1677        assert_eq!(
1678            prediction.edits[0].0.to_point(&snapshot).start,
1679            language::Point::new(1, 3)
1680        );
1681        assert_eq!(prediction.edits[0].1, " are you?");
1682    }
1683
1684    #[gpui::test]
1685    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1686        let (zeta, mut req_rx) = init_test(cx);
1687        let fs = FakeFs::new(cx.executor());
1688        fs.insert_tree(
1689            "/root",
1690            json!({
1691                "foo.md": "Hello!\nBye"
1692            }),
1693        )
1694        .await;
1695        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1696
1697        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1698        let diagnostic = lsp::Diagnostic {
1699            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1700            severity: Some(lsp::DiagnosticSeverity::ERROR),
1701            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1702            ..Default::default()
1703        };
1704
1705        project.update(cx, |project, cx| {
1706            project.lsp_store().update(cx, |lsp_store, cx| {
1707                // Create some diagnostics
1708                lsp_store
1709                    .update_diagnostics(
1710                        LanguageServerId(0),
1711                        lsp::PublishDiagnosticsParams {
1712                            uri: path_to_buffer_uri.clone(),
1713                            diagnostics: vec![diagnostic],
1714                            version: None,
1715                        },
1716                        None,
1717                        language::DiagnosticSourceKind::Pushed,
1718                        &[],
1719                        cx,
1720                    )
1721                    .unwrap();
1722            });
1723        });
1724
1725        let buffer = project
1726            .update(cx, |project, cx| {
1727                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1728                project.open_buffer(path, cx)
1729            })
1730            .await
1731            .unwrap();
1732
1733        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1734        let position = snapshot.anchor_before(language::Point::new(0, 0));
1735
1736        let _prediction_task = zeta.update(cx, |zeta, cx| {
1737            zeta.request_prediction(&project, &buffer, position, cx)
1738        });
1739
1740        let (request, _respond_tx) = req_rx.next().await.unwrap();
1741
1742        assert_eq!(request.diagnostic_groups.len(), 1);
1743        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1744            .unwrap();
1745        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1746        assert_eq!(
1747            value,
1748            json!({
1749                "entries": [{
1750                    "range": {
1751                        "start": 8,
1752                        "end": 10
1753                    },
1754                    "diagnostic": {
1755                        "source": null,
1756                        "code": null,
1757                        "code_description": null,
1758                        "severity": 1,
1759                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1760                        "markdown": null,
1761                        "group_id": 0,
1762                        "is_primary": true,
1763                        "is_disk_based": false,
1764                        "is_unnecessary": false,
1765                        "source_kind": "Pushed",
1766                        "data": null,
1767                        "underline": true
1768                    }
1769                }],
1770                "primary_ix": 0
1771            })
1772        );
1773    }
1774
1775    fn init_test(
1776        cx: &mut TestAppContext,
1777    ) -> (
1778        Entity<Zeta>,
1779        mpsc::UnboundedReceiver<(
1780            predict_edits_v3::PredictEditsRequest,
1781            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1782        )>,
1783    ) {
1784        cx.update(move |cx| {
1785            let settings_store = SettingsStore::test(cx);
1786            cx.set_global(settings_store);
1787            language::init(cx);
1788            Project::init_settings(cx);
1789
1790            let (req_tx, req_rx) = mpsc::unbounded();
1791
1792            let http_client = FakeHttpClient::create({
1793                move |req| {
1794                    let uri = req.uri().path().to_string();
1795                    let mut body = req.into_body();
1796                    let req_tx = req_tx.clone();
1797                    async move {
1798                        let resp = match uri.as_str() {
1799                            "/client/llm_tokens" => serde_json::to_string(&json!({
1800                                "token": "test"
1801                            }))
1802                            .unwrap(),
1803                            "/predict_edits/v3" => {
1804                                let mut buf = Vec::new();
1805                                body.read_to_end(&mut buf).await.ok();
1806                                let req = serde_json::from_slice(&buf).unwrap();
1807
1808                                let (res_tx, res_rx) = oneshot::channel();
1809                                req_tx.unbounded_send((req, res_tx)).unwrap();
1810                                serde_json::to_string(&res_rx.await?).unwrap()
1811                            }
1812                            _ => {
1813                                panic!("Unexpected path: {}", uri)
1814                            }
1815                        };
1816
1817                        Ok(Response::builder().body(resp.into()).unwrap())
1818                    }
1819                }
1820            });
1821
1822            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1823            client.cloud_client().set_credentials(1, "test".into());
1824
1825            language_model::init(client.clone(), cx);
1826
1827            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1828            let zeta = Zeta::global(&client, &user_store, cx);
1829
1830            (zeta, req_rx)
1831        })
1832    }
1833}