zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
  10use collections::HashMap;
  11use edit_prediction_context::{
  12    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  13    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  14    SyntaxIndex, SyntaxIndexState,
  15};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::AsyncReadExt as _;
  18use futures::channel::{mpsc, oneshot};
  19use gpui::http_client::{AsyncBody, Method};
  20use gpui::{
  21    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  22    http_client, prelude::*,
  23};
  24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language::{BufferSnapshot, OffsetRangeExt};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use serde::de::DeserializeOwned;
  30use std::collections::{VecDeque, hash_map};
  31use std::ops::Range;
  32use std::path::Path;
  33use std::str::FromStr as _;
  34use std::sync::Arc;
  35use std::time::{Duration, Instant};
  36use thiserror::Error;
  37use util::ResultExt as _;
  38use util::rel_path::RelPathBuf;
  39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  40
  41mod merge_excerpts;
  42mod prediction;
  43mod provider;
  44mod related_excerpts;
  45
  46use crate::merge_excerpts::merge_excerpts;
  47use crate::prediction::EditPrediction;
  48use crate::related_excerpts::find_related_excerpts;
  49pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
  50pub use provider::ZetaEditPredictionProvider;
  51
/// Maximum gap between two consecutive edits to the same buffer for them to be
/// coalesced into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When a project's event queue reaches this size, `Zeta::push_event` drops the
/// oldest half before appending the next event.
const MAX_EVENT_COUNT: usize = 16;
  56
/// Default excerpt sizing, shared by the LLM and syntax context defaults below.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context mode: LLM-driven retrieval with `DEFAULT_LLM_CONTEXT_OPTIONS`.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);

/// Default options for `ContextMode::Llm`.
pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for `ContextMode::Syntax`.
///
/// Not referenced by `DEFAULT_OPTIONS`; callers that want syntax-based context
/// must opt in explicitly.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Options a freshly constructed `Zeta` starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
  86
/// Feature flag registered under the name `"zeta2"`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // NOTE(review): returning `false` presumably means staff do not get the flag
    // automatically — confirm against the `FeatureFlag` trait's contract.
    fn enabled_for_staff() -> bool {
        false
    }
}
  96
/// Newtype wrapper that stores the shared `Zeta` entity as an app global.
///
/// Installed by `Zeta::global` and read back by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 101
/// Edit prediction state shared across projects: per-project event history,
/// registered buffers, current prediction, and the settings used to build
/// prediction requests.
pub struct Zeta {
    /// Client used to build LLM service URLs and to refresh the LLM token.
    client: Arc<Client>,
    /// Source of the user's edit prediction usage (see `Zeta::usage`).
    user_store: Entity<UserStore>,
    /// Token passed along with API requests; refreshed whenever the
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for the lifetime of `Zeta`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Current request/context settings; starts as `DEFAULT_OPTIONS`.
    options: ZetaOptions,
    /// Initialized to `false` in `Zeta::new`.
    // NOTE(review): presumably set when the server reports a minimum required
    // version — the writer is not visible in this part of the file.
    update_required: bool,
    /// When set (via `Zeta::debug_info`), debug information about each
    /// prediction request is sent over this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
 112
/// Tunable settings for building edit prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context for the request is gathered (LLM-driven or syntax-based).
    pub context: ContextMode,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget passed to `gather_nearby_diagnostics` when collecting
    /// diagnostics around the cursor.
    pub max_diagnostic_bytes: usize,
    /// Prompt format forwarded in the request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism handed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
 121
/// Selects how context for a prediction request is gathered, carrying the
/// options specific to each strategy.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Build the request from the cursor excerpt plus the project's retrieved
    /// context ranges.
    Llm(LlmContextOptions),
    /// Build the request from `EditPredictionContext::gather_context`.
    Syntax(EditPredictionContextOptions),
}
 127
 128impl ContextMode {
 129    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 130        match self {
 131            ContextMode::Llm(options) => &options.excerpt,
 132            ContextMode::Syntax(options) => &options.excerpt,
 133        }
 134    }
 135}
 136
/// Debug messages delivered to the receiver returned by `Zeta::debug_info`.
///
/// `EditPredicted` is sent from `Zeta::request_prediction`.
// NOTE(review): the remaining variants are presumably emitted by the context
// retrieval pipeline — their senders are not visible in this part of the file.
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredicted(ZetaEditPredictionDebugInfo),
}
 145
/// Payload for the context retrieval stages of `ZetaDebugInfo`.
pub struct ZetaContextRetrievalDebugInfo {
    /// Project the retrieval belongs to.
    pub project: Entity<Project>,
    /// When the stage was reported.
    pub timestamp: Instant,
}

/// Payload for `ZetaDebugInfo::EditPredicted`, describing one prediction request.
pub struct ZetaEditPredictionDebugInfo {
    /// The request that was built for the server.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Wall-clock time spent assembling the request (diagnostics and context).
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt built locally from `request`, or the stringified build error.
    pub local_prompt: Result<String, String>,
    /// Receives the outcome of the request as a response or an error string.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
pub struct ZetaSearchQueryDebugInfo {
    /// Project the queries were generated for.
    pub project: Entity<Project>,
    /// When the queries were reported.
    pub timestamp: Instant,
    /// The generated search queries.
    pub queries: Vec<SearchToolQuery>,
}

/// Alias for the debug info type defined by the prediction API.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 167
/// State that `Zeta` tracks for a single project.
struct ZetaProject {
    /// Syntax index for the project, created on first registration.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id. Entries are
    /// removed when the buffer entity is released.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently on offer for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Extra buffers and anchor ranges that `request_prediction` includes as
    /// additional files in the request.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    // NOTE(review): the three fields below start as `None`; they presumably
    // drive debounced context refreshes — their users are not visible in this
    // part of the file.
    refresh_context_task: Option<Task<Option<()>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}
 178
/// A prediction together with the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer that triggered the request. This can differ from
    /// the buffer the prediction edits.
    pub requested_by_buffer_id: EntityId,
    /// The prediction itself.
    pub prediction: EditPrediction,
}
 184
 185impl CurrentEditPrediction {
 186    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 187        let Some(new_edits) = self
 188            .prediction
 189            .interpolate(&self.prediction.buffer.read(cx))
 190        else {
 191            return false;
 192        };
 193
 194        if self.prediction.buffer != old_prediction.prediction.buffer {
 195            return true;
 196        }
 197
 198        let Some(old_edits) = old_prediction
 199            .prediction
 200            .interpolate(&old_prediction.prediction.buffer.read(cx))
 201        else {
 202            return true;
 203        };
 204
 205        // This reduces the occurrence of UI thrash from replacing edits
 206        //
 207        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 208        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 209            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 210            && old_edits.len() == 1
 211            && new_edits.len() == 1
 212        {
 213            let (old_range, old_text) = &old_edits[0];
 214            let (new_range, new_text) = &new_edits[0];
 215            new_range == old_range && new_text.starts_with(old_text)
 216        } else {
 217            true
 218        }
 219    }
 220}
 221
/// A prediction from the perspective of a buffer.
///
/// Produced by `Zeta::current_prediction_for_buffer`.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets the buffer it is being viewed from.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets another one.
    Jump { prediction: &'a EditPrediction },
}
 228
/// Bookkeeping for a buffer whose edits are being recorded.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; serves as the "old" side of the
    /// next `Event::BufferChange`.
    snapshot: BufferSnapshot,
    /// The buffer-event subscription and the release observer, in that order.
    /// Held only to keep them alive.
    _subscriptions: [gpui::Subscription; 2],
}
 233
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change (or before the first coalesced edit).
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change (or after the last coalesced edit).
        new_snapshot: BufferSnapshot,
        /// When the most recent edit folded into this event was recorded.
        timestamp: Instant,
    },
}
 242
 243impl Event {
 244    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 245        match self {
 246            Event::BufferChange {
 247                old_snapshot,
 248                new_snapshot,
 249                ..
 250            } => {
 251                let path = new_snapshot.file().map(|f| f.full_path(cx));
 252
 253                let old_path = old_snapshot.file().and_then(|f| {
 254                    let old_path = f.full_path(cx);
 255                    if Some(&old_path) != path.as_ref() {
 256                        Some(old_path)
 257                    } else {
 258                        None
 259                    }
 260                });
 261
 262                // TODO [zeta2] move to bg?
 263                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 264
 265                if path == old_path && diff.is_empty() {
 266                    None
 267                } else {
 268                    Some(predict_edits_v3::Event::BufferChange {
 269                        old_path,
 270                        path,
 271                        diff,
 272                        //todo: Actually detect if this edit was predicted or not
 273                        predicted: false,
 274                    })
 275                }
 276            }
 277        }
 278    }
 279}
 280
 281impl Zeta {
 282    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 283        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 284    }
 285
 286    pub fn global(
 287        client: &Arc<Client>,
 288        user_store: &Entity<UserStore>,
 289        cx: &mut App,
 290    ) -> Entity<Self> {
 291        cx.try_global::<ZetaGlobal>()
 292            .map(|global| global.0.clone())
 293            .unwrap_or_else(|| {
 294                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 295                cx.set_global(ZetaGlobal(zeta.clone()));
 296                zeta
 297            })
 298    }
 299
 300    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 301        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 302
 303        Self {
 304            projects: HashMap::default(),
 305            client,
 306            user_store,
 307            options: DEFAULT_OPTIONS,
 308            llm_token: LlmApiToken::default(),
 309            _llm_token_subscription: cx.subscribe(
 310                &refresh_llm_token_listener,
 311                |this, _listener, _event, cx| {
 312                    let client = this.client.clone();
 313                    let llm_token = this.llm_token.clone();
 314                    cx.spawn(async move |_this, _cx| {
 315                        llm_token.refresh(&client).await?;
 316                        anyhow::Ok(())
 317                    })
 318                    .detach_and_log_err(cx);
 319                },
 320            ),
 321            update_required: false,
 322            debug_tx: None,
 323        }
 324    }
 325
 326    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 327        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 328        self.debug_tx = Some(debug_watch_tx);
 329        debug_watch_rx
 330    }
 331
 332    pub fn options(&self) -> &ZetaOptions {
 333        &self.options
 334    }
 335
 336    pub fn set_options(&mut self, options: ZetaOptions) {
 337        self.options = options;
 338    }
 339
 340    pub fn clear_history(&mut self) {
 341        for zeta_project in self.projects.values_mut() {
 342            zeta_project.events.clear();
 343        }
 344    }
 345
 346    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 347        self.projects
 348            .get(&project.entity_id())
 349            .map(|project| project.events.iter())
 350            .into_iter()
 351            .flatten()
 352    }
 353
 354    pub fn context_for_project(
 355        &self,
 356        project: &Entity<Project>,
 357    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 358        self.projects
 359            .get(&project.entity_id())
 360            .and_then(|project| {
 361                Some(
 362                    project
 363                        .context
 364                        .as_ref()?
 365                        .iter()
 366                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 367                )
 368            })
 369            .into_iter()
 370            .flatten()
 371    }
 372
 373    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 374        self.user_store.read(cx).edit_prediction_usage()
 375    }
 376
 377    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 378        self.get_or_init_zeta_project(project, cx);
 379    }
 380
 381    pub fn register_buffer(
 382        &mut self,
 383        buffer: &Entity<Buffer>,
 384        project: &Entity<Project>,
 385        cx: &mut Context<Self>,
 386    ) {
 387        let zeta_project = self.get_or_init_zeta_project(project, cx);
 388        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 389    }
 390
 391    fn get_or_init_zeta_project(
 392        &mut self,
 393        project: &Entity<Project>,
 394        cx: &mut App,
 395    ) -> &mut ZetaProject {
 396        self.projects
 397            .entry(project.entity_id())
 398            .or_insert_with(|| ZetaProject {
 399                syntax_index: cx.new(|cx| {
 400                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 401                }),
 402                events: VecDeque::new(),
 403                registered_buffers: HashMap::default(),
 404                current_prediction: None,
 405                context: None,
 406                refresh_context_task: None,
 407                refresh_context_debounce_task: None,
 408                refresh_context_timestamp: None,
 409            })
 410    }
 411
 412    fn register_buffer_impl<'a>(
 413        zeta_project: &'a mut ZetaProject,
 414        buffer: &Entity<Buffer>,
 415        project: &Entity<Project>,
 416        cx: &mut Context<Self>,
 417    ) -> &'a mut RegisteredBuffer {
 418        let buffer_id = buffer.entity_id();
 419        match zeta_project.registered_buffers.entry(buffer_id) {
 420            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 421            hash_map::Entry::Vacant(entry) => {
 422                let snapshot = buffer.read(cx).snapshot();
 423                let project_entity_id = project.entity_id();
 424                entry.insert(RegisteredBuffer {
 425                    snapshot,
 426                    _subscriptions: [
 427                        cx.subscribe(buffer, {
 428                            let project = project.downgrade();
 429                            move |this, buffer, event, cx| {
 430                                if let language::BufferEvent::Edited = event
 431                                    && let Some(project) = project.upgrade()
 432                                {
 433                                    this.report_changes_for_buffer(&buffer, &project, cx);
 434                                }
 435                            }
 436                        }),
 437                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 438                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 439                            else {
 440                                return;
 441                            };
 442                            zeta_project.registered_buffers.remove(&buffer_id);
 443                        }),
 444                    ],
 445                })
 446            }
 447        }
 448    }
 449
 450    fn report_changes_for_buffer(
 451        &mut self,
 452        buffer: &Entity<Buffer>,
 453        project: &Entity<Project>,
 454        cx: &mut Context<Self>,
 455    ) -> BufferSnapshot {
 456        let zeta_project = self.get_or_init_zeta_project(project, cx);
 457        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 458
 459        let new_snapshot = buffer.read(cx).snapshot();
 460        if new_snapshot.version != registered_buffer.snapshot.version {
 461            let old_snapshot =
 462                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 463            Self::push_event(
 464                zeta_project,
 465                Event::BufferChange {
 466                    old_snapshot,
 467                    new_snapshot: new_snapshot.clone(),
 468                    timestamp: Instant::now(),
 469                },
 470            );
 471        }
 472
 473        new_snapshot
 474    }
 475
 476    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 477        let events = &mut zeta_project.events;
 478
 479        if let Some(Event::BufferChange {
 480            new_snapshot: last_new_snapshot,
 481            timestamp: last_timestamp,
 482            ..
 483        }) = events.back_mut()
 484        {
 485            // Coalesce edits for the same buffer when they happen one after the other.
 486            let Event::BufferChange {
 487                old_snapshot,
 488                new_snapshot,
 489                timestamp,
 490            } = &event;
 491
 492            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 493                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 494                && old_snapshot.version == last_new_snapshot.version
 495            {
 496                *last_new_snapshot = new_snapshot.clone();
 497                *last_timestamp = *timestamp;
 498                return;
 499            }
 500        }
 501
 502        if events.len() >= MAX_EVENT_COUNT {
 503            // These are halved instead of popping to improve prompt caching.
 504            events.drain(..MAX_EVENT_COUNT / 2);
 505        }
 506
 507        events.push_back(event);
 508    }
 509
 510    fn current_prediction_for_buffer(
 511        &self,
 512        buffer: &Entity<Buffer>,
 513        project: &Entity<Project>,
 514        cx: &App,
 515    ) -> Option<BufferEditPrediction<'_>> {
 516        let project_state = self.projects.get(&project.entity_id())?;
 517
 518        let CurrentEditPrediction {
 519            requested_by_buffer_id,
 520            prediction,
 521        } = project_state.current_prediction.as_ref()?;
 522
 523        if prediction.targets_buffer(buffer.read(cx), cx) {
 524            Some(BufferEditPrediction::Local { prediction })
 525        } else if *requested_by_buffer_id == buffer.entity_id() {
 526            Some(BufferEditPrediction::Jump { prediction })
 527        } else {
 528            None
 529        }
 530    }
 531
 532    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 533        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 534            return;
 535        };
 536
 537        let Some(prediction) = project_state.current_prediction.take() else {
 538            return;
 539        };
 540        let request_id = prediction.prediction.id.into();
 541
 542        let client = self.client.clone();
 543        let llm_token = self.llm_token.clone();
 544        let app_version = AppVersion::global(cx);
 545        cx.spawn(async move |this, cx| {
 546            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 547                http_client::Url::parse(&predict_edits_url)?
 548            } else {
 549                client
 550                    .http_client()
 551                    .build_zed_llm_url("/predict_edits/accept", &[])?
 552            };
 553
 554            let response = cx
 555                .background_spawn(Self::send_api_request::<()>(
 556                    move |builder| {
 557                        let req = builder.uri(url.as_ref()).body(
 558                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 559                        );
 560                        Ok(req?)
 561                    },
 562                    client,
 563                    llm_token,
 564                    app_version,
 565                ))
 566                .await;
 567
 568            Self::handle_api_response(&this, response, cx)?;
 569            anyhow::Ok(())
 570        })
 571        .detach_and_log_err(cx);
 572    }
 573
 574    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 575        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 576            project_state.current_prediction.take();
 577        };
 578    }
 579
 580    pub fn refresh_prediction(
 581        &mut self,
 582        project: &Entity<Project>,
 583        buffer: &Entity<Buffer>,
 584        position: language::Anchor,
 585        cx: &mut Context<Self>,
 586    ) -> Task<Result<()>> {
 587        let request_task = self.request_prediction(project, buffer, position, cx);
 588        let buffer = buffer.clone();
 589        let project = project.clone();
 590
 591        cx.spawn(async move |this, cx| {
 592            if let Some(prediction) = request_task.await? {
 593                this.update(cx, |this, cx| {
 594                    let project_state = this
 595                        .projects
 596                        .get_mut(&project.entity_id())
 597                        .context("Project not found")?;
 598
 599                    let new_prediction = CurrentEditPrediction {
 600                        requested_by_buffer_id: buffer.entity_id(),
 601                        prediction: prediction,
 602                    };
 603
 604                    if project_state
 605                        .current_prediction
 606                        .as_ref()
 607                        .is_none_or(|old_prediction| {
 608                            new_prediction.should_replace_prediction(&old_prediction, cx)
 609                        })
 610                    {
 611                        project_state.current_prediction = Some(new_prediction);
 612                    }
 613                    anyhow::Ok(())
 614                })??;
 615            }
 616            Ok(())
 617        })
 618    }
 619
 620    fn request_prediction(
 621        &mut self,
 622        project: &Entity<Project>,
 623        buffer: &Entity<Buffer>,
 624        position: language::Anchor,
 625        cx: &mut Context<Self>,
 626    ) -> Task<Result<Option<EditPrediction>>> {
 627        let project_state = self.projects.get(&project.entity_id());
 628
 629        let index_state = project_state.map(|state| {
 630            state
 631                .syntax_index
 632                .read_with(cx, |index, _cx| index.state().clone())
 633        });
 634        let options = self.options.clone();
 635        let snapshot = buffer.read(cx).snapshot();
 636        let Some(excerpt_path) = snapshot
 637            .file()
 638            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 639        else {
 640            return Task::ready(Err(anyhow!("No file path for excerpt")));
 641        };
 642        let client = self.client.clone();
 643        let llm_token = self.llm_token.clone();
 644        let app_version = AppVersion::global(cx);
 645        let worktree_snapshots = project
 646            .read(cx)
 647            .worktrees(cx)
 648            .map(|worktree| worktree.read(cx).snapshot())
 649            .collect::<Vec<_>>();
 650        let debug_tx = self.debug_tx.clone();
 651
 652        let events = project_state
 653            .map(|state| {
 654                state
 655                    .events
 656                    .iter()
 657                    .filter_map(|event| event.to_request_event(cx))
 658                    .collect::<Vec<_>>()
 659            })
 660            .unwrap_or_default();
 661
 662        let diagnostics = snapshot.diagnostic_sets().clone();
 663
 664        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 665            let mut path = f.worktree.read(cx).absolutize(&f.path);
 666            if path.pop() { Some(path) } else { None }
 667        });
 668
 669        // TODO data collection
 670        let can_collect_data = cx.is_staff();
 671
 672        let mut included_files = project_state
 673            .and_then(|project_state| project_state.context.as_ref())
 674            .unwrap_or(&HashMap::default())
 675            .iter()
 676            .filter_map(|(buffer, ranges)| {
 677                let buffer = buffer.read(cx);
 678                Some((
 679                    buffer.snapshot(),
 680                    buffer.file()?.full_path(cx).into(),
 681                    ranges.clone(),
 682                ))
 683            })
 684            .collect::<Vec<_>>();
 685
 686        let request_task = cx.background_spawn({
 687            let snapshot = snapshot.clone();
 688            let buffer = buffer.clone();
 689            async move {
 690                let index_state = if let Some(index_state) = index_state {
 691                    Some(index_state.lock_owned().await)
 692                } else {
 693                    None
 694                };
 695
 696                let cursor_offset = position.to_offset(&snapshot);
 697                let cursor_point = cursor_offset.to_point(&snapshot);
 698
 699                let before_retrieval = chrono::Utc::now();
 700
 701                let (diagnostic_groups, diagnostic_groups_truncated) =
 702                    Self::gather_nearby_diagnostics(
 703                        cursor_offset,
 704                        &diagnostics,
 705                        &snapshot,
 706                        options.max_diagnostic_bytes,
 707                    );
 708
 709                let request = match options.context {
 710                    ContextMode::Llm(context_options) => {
 711                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 712                            cursor_point,
 713                            &snapshot,
 714                            &context_options.excerpt,
 715                            index_state.as_deref(),
 716                        ) else {
 717                            return Ok((None, None));
 718                        };
 719
 720                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
 721                            ..snapshot.anchor_before(excerpt.range.end);
 722
 723                        if let Some(buffer_ix) = included_files
 724                            .iter()
 725                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
 726                        {
 727                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
 728                            let range_ix = ranges
 729                                .binary_search_by(|probe| {
 730                                    probe
 731                                        .start
 732                                        .cmp(&excerpt_anchor_range.start, buffer)
 733                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 734                                })
 735                                .unwrap_or_else(|ix| ix);
 736
 737                            ranges.insert(range_ix, excerpt_anchor_range);
 738                            let last_ix = included_files.len() - 1;
 739                            included_files.swap(buffer_ix, last_ix);
 740                        } else {
 741                            included_files.push((
 742                                snapshot,
 743                                excerpt_path.clone(),
 744                                vec![excerpt_anchor_range],
 745                            ));
 746                        }
 747
 748                        let included_files = included_files
 749                            .into_iter()
 750                            .map(|(buffer, path, ranges)| {
 751                                let excerpts = merge_excerpts(
 752                                    &buffer,
 753                                    ranges.iter().map(|range| {
 754                                        let point_range = range.to_point(&buffer);
 755                                        Line(point_range.start.row)..Line(point_range.end.row)
 756                                    }),
 757                                );
 758                                predict_edits_v3::IncludedFile {
 759                                    path,
 760                                    max_row: Line(buffer.max_point().row),
 761                                    excerpts,
 762                                }
 763                            })
 764                            .collect::<Vec<_>>();
 765
 766                        predict_edits_v3::PredictEditsRequest {
 767                            excerpt_path,
 768                            excerpt: String::new(),
 769                            excerpt_line_range: Line(0)..Line(0),
 770                            excerpt_range: 0..0,
 771                            cursor_point: predict_edits_v3::Point {
 772                                line: predict_edits_v3::Line(cursor_point.row),
 773                                column: cursor_point.column,
 774                            },
 775                            included_files,
 776                            referenced_declarations: vec![],
 777                            events,
 778                            can_collect_data,
 779                            diagnostic_groups,
 780                            diagnostic_groups_truncated,
 781                            debug_info: debug_tx.is_some(),
 782                            prompt_max_bytes: Some(options.max_prompt_bytes),
 783                            prompt_format: options.prompt_format,
 784                            // TODO [zeta2]
 785                            signatures: vec![],
 786                            excerpt_parent: None,
 787                            git_info: None,
 788                        }
 789                    }
 790                    ContextMode::Syntax(context_options) => {
 791                        let Some(context) = EditPredictionContext::gather_context(
 792                            cursor_point,
 793                            &snapshot,
 794                            parent_abs_path.as_deref(),
 795                            &context_options,
 796                            index_state.as_deref(),
 797                        ) else {
 798                            return Ok((None, None));
 799                        };
 800
 801                        make_syntax_context_cloud_request(
 802                            excerpt_path,
 803                            context,
 804                            events,
 805                            can_collect_data,
 806                            diagnostic_groups,
 807                            diagnostic_groups_truncated,
 808                            None,
 809                            debug_tx.is_some(),
 810                            &worktree_snapshots,
 811                            index_state.as_deref(),
 812                            Some(options.max_prompt_bytes),
 813                            options.prompt_format,
 814                        )
 815                    }
 816                };
 817
 818                let retrieval_time = chrono::Utc::now() - before_retrieval;
 819
 820                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 821                    let (response_tx, response_rx) = oneshot::channel();
 822
 823                    let local_prompt = build_prompt(&request)
 824                        .map(|(prompt, _)| prompt)
 825                        .map_err(|err| err.to_string());
 826
 827                    debug_tx
 828                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
 829                            request: request.clone(),
 830                            retrieval_time,
 831                            buffer: buffer.downgrade(),
 832                            local_prompt,
 833                            position,
 834                            response_rx,
 835                        }))
 836                        .ok();
 837                    Some(response_tx)
 838                } else {
 839                    None
 840                };
 841
 842                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 843                    if let Some(debug_response_tx) = debug_response_tx {
 844                        debug_response_tx
 845                            .send(Err("Request skipped".to_string()))
 846                            .ok();
 847                    }
 848                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 849                }
 850
 851                let response =
 852                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 853
 854                if let Some(debug_response_tx) = debug_response_tx {
 855                    debug_response_tx
 856                        .send(
 857                            response
 858                                .as_ref()
 859                                .map_err(|err| err.to_string())
 860                                .map(|response| response.0.clone()),
 861                        )
 862                        .ok();
 863                }
 864
 865                response.map(|(res, usage)| (Some(res), usage))
 866            }
 867        });
 868
 869        let buffer = buffer.clone();
 870
 871        cx.spawn({
 872            let project = project.clone();
 873            async move |this, cx| {
 874                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 875                else {
 876                    return Ok(None);
 877                };
 878
 879                // TODO telemetry: duration, etc
 880                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 881            }
 882        })
 883    }
 884
 885    async fn send_prediction_request(
 886        client: Arc<Client>,
 887        llm_token: LlmApiToken,
 888        app_version: SemanticVersion,
 889        request: predict_edits_v3::PredictEditsRequest,
 890    ) -> Result<(
 891        predict_edits_v3::PredictEditsResponse,
 892        Option<EditPredictionUsage>,
 893    )> {
 894        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 895            http_client::Url::parse(&predict_edits_url)?
 896        } else {
 897            client
 898                .http_client()
 899                .build_zed_llm_url("/predict_edits/v3", &[])?
 900        };
 901
 902        Self::send_api_request(
 903            |builder| {
 904                let req = builder
 905                    .uri(url.as_ref())
 906                    .body(serde_json::to_string(&request)?.into());
 907                Ok(req?)
 908            },
 909            client,
 910            llm_token,
 911            app_version,
 912        )
 913        .await
 914    }
 915
 916    fn handle_api_response<T>(
 917        this: &WeakEntity<Self>,
 918        response: Result<(T, Option<EditPredictionUsage>)>,
 919        cx: &mut gpui::AsyncApp,
 920    ) -> Result<T> {
 921        match response {
 922            Ok((data, usage)) => {
 923                if let Some(usage) = usage {
 924                    this.update(cx, |this, cx| {
 925                        this.user_store.update(cx, |user_store, cx| {
 926                            user_store.update_edit_prediction_usage(usage, cx);
 927                        });
 928                    })
 929                    .ok();
 930                }
 931                Ok(data)
 932            }
 933            Err(err) => {
 934                if err.is::<ZedUpdateRequiredError>() {
 935                    cx.update(|cx| {
 936                        this.update(cx, |this, _cx| {
 937                            this.update_required = true;
 938                        })
 939                        .ok();
 940
 941                        let error_message: SharedString = err.to_string().into();
 942                        show_app_notification(
 943                            NotificationId::unique::<ZedUpdateRequiredError>(),
 944                            cx,
 945                            move |cx| {
 946                                cx.new(|cx| {
 947                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 948                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 949                                })
 950                            },
 951                        );
 952                    })
 953                    .ok();
 954                }
 955                Err(err)
 956            }
 957        }
 958    }
 959
 960    async fn send_api_request<Res>(
 961        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 962        client: Arc<Client>,
 963        llm_token: LlmApiToken,
 964        app_version: SemanticVersion,
 965    ) -> Result<(Res, Option<EditPredictionUsage>)>
 966    where
 967        Res: DeserializeOwned,
 968    {
 969        let http_client = client.http_client();
 970        let mut token = llm_token.acquire(&client).await?;
 971        let mut did_retry = false;
 972
 973        loop {
 974            let request_builder = http_client::Request::builder().method(Method::POST);
 975
 976            let request = build(
 977                request_builder
 978                    .header("Content-Type", "application/json")
 979                    .header("Authorization", format!("Bearer {}", token))
 980                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 981            )?;
 982
 983            let mut response = http_client.send(request).await?;
 984
 985            if let Some(minimum_required_version) = response
 986                .headers()
 987                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 988                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 989            {
 990                anyhow::ensure!(
 991                    app_version >= minimum_required_version,
 992                    ZedUpdateRequiredError {
 993                        minimum_version: minimum_required_version
 994                    }
 995                );
 996            }
 997
 998            if response.status().is_success() {
 999                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1000
1001                let mut body = Vec::new();
1002                response.body_mut().read_to_end(&mut body).await?;
1003                return Ok((serde_json::from_slice(&body)?, usage));
1004            } else if !did_retry
1005                && response
1006                    .headers()
1007                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1008                    .is_some()
1009            {
1010                did_retry = true;
1011                token = llm_token.refresh(&client).await?;
1012            } else {
1013                let mut body = String::new();
1014                response.body_mut().read_to_string(&mut body).await?;
1015                anyhow::bail!(
1016                    "Request failed with status: {:?}\nBody: {}",
1017                    response.status(),
1018                    body
1019                );
1020            }
1021        }
1022    }
1023
    /// Threshold used by `refresh_context_if_needed`: when more than this much time has
    /// passed since the previously recorded refresh timestamp, the user is considered
    /// to have been idle and the context is refreshed without debouncing.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// How long `refresh_context_if_needed` waits before refreshing the context when
    /// the user was not idle.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1026
1027    // Refresh the related excerpts when the user just beguns editing after
1028    // an idle period, and after they pause editing.
1029    fn refresh_context_if_needed(
1030        &mut self,
1031        project: &Entity<Project>,
1032        buffer: &Entity<language::Buffer>,
1033        cursor_position: language::Anchor,
1034        cx: &mut Context<Self>,
1035    ) {
1036        if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1037            return;
1038        }
1039
1040        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1041            return;
1042        };
1043
1044        let now = Instant::now();
1045        let was_idle = zeta_project
1046            .refresh_context_timestamp
1047            .map_or(true, |timestamp| {
1048                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1049            });
1050        zeta_project.refresh_context_timestamp = Some(now);
1051        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1052            let buffer = buffer.clone();
1053            let project = project.clone();
1054            async move |this, cx| {
1055                if was_idle {
1056                    log::debug!("refetching edit prediction context after idle");
1057                } else {
1058                    cx.background_executor()
1059                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1060                        .await;
1061                    log::debug!("refetching edit prediction context after pause");
1062                }
1063                this.update(cx, |this, cx| {
1064                    this.refresh_context(project, buffer, cursor_position, cx);
1065                })
1066                .ok()
1067            }
1068        }));
1069    }
1070
1071    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1072    // and avoid spawning more than one concurrent task.
1073    fn refresh_context(
1074        &mut self,
1075        project: Entity<Project>,
1076        buffer: Entity<language::Buffer>,
1077        cursor_position: language::Anchor,
1078        cx: &mut Context<Self>,
1079    ) {
1080        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1081            return;
1082        };
1083
1084        let debug_tx = self.debug_tx.clone();
1085
1086        zeta_project
1087            .refresh_context_task
1088            .get_or_insert(cx.spawn(async move |this, cx| {
1089                if let Some(debug_tx) = &debug_tx {
1090                    debug_tx
1091                        .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
1092                            ZetaContextRetrievalDebugInfo {
1093                                project: project.clone(),
1094                                timestamp: Instant::now(),
1095                            },
1096                        ))
1097                        .ok();
1098                }
1099
1100                let related_excerpts = this
1101                    .update(cx, |this, cx| {
1102                        let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1103                            return Task::ready(anyhow::Ok(HashMap::default()));
1104                        };
1105
1106                        let ContextMode::Llm(options) = &this.options().context else {
1107                            return Task::ready(anyhow::Ok(HashMap::default()));
1108                        };
1109
1110                        find_related_excerpts(
1111                            buffer.clone(),
1112                            cursor_position,
1113                            &project,
1114                            zeta_project.events.iter(),
1115                            options,
1116                            debug_tx,
1117                            cx,
1118                        )
1119                    })
1120                    .ok()?
1121                    .await
1122                    .log_err()
1123                    .unwrap_or_default();
1124                this.update(cx, |this, _cx| {
1125                    let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1126                        return;
1127                    };
1128                    zeta_project.context = Some(related_excerpts);
1129                    zeta_project.refresh_context_task.take();
1130                    if let Some(debug_tx) = &this.debug_tx {
1131                        debug_tx
1132                            .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1133                                ZetaContextRetrievalDebugInfo {
1134                                    project,
1135                                    timestamp: Instant::now(),
1136                                },
1137                            ))
1138                            .ok();
1139                    }
1140                })
1141                .ok()
1142            }));
1143    }
1144
1145    fn gather_nearby_diagnostics(
1146        cursor_offset: usize,
1147        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1148        snapshot: &BufferSnapshot,
1149        max_diagnostics_bytes: usize,
1150    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1151        // TODO: Could make this more efficient
1152        let mut diagnostic_groups = Vec::new();
1153        for (language_server_id, diagnostics) in diagnostic_sets {
1154            let mut groups = Vec::new();
1155            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1156            diagnostic_groups.extend(
1157                groups
1158                    .into_iter()
1159                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1160            );
1161        }
1162
1163        // sort by proximity to cursor
1164        diagnostic_groups.sort_by_key(|group| {
1165            let range = &group.entries[group.primary_ix].range;
1166            if range.start >= cursor_offset {
1167                range.start - cursor_offset
1168            } else if cursor_offset >= range.end {
1169                cursor_offset - range.end
1170            } else {
1171                (cursor_offset - range.start).min(range.end - cursor_offset)
1172            }
1173        });
1174
1175        let mut results = Vec::new();
1176        let mut diagnostic_groups_truncated = false;
1177        let mut diagnostics_byte_count = 0;
1178        for group in diagnostic_groups {
1179            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1180            diagnostics_byte_count += raw_value.get().len();
1181            if diagnostics_byte_count > max_diagnostics_bytes {
1182                diagnostic_groups_truncated = true;
1183                break;
1184            }
1185            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1186        }
1187
1188        (results, diagnostic_groups_truncated)
1189    }
1190
1191    // TODO: Dedupe with similar code in request_prediction?
1192    pub fn cloud_request_for_zeta_cli(
1193        &mut self,
1194        project: &Entity<Project>,
1195        buffer: &Entity<Buffer>,
1196        position: language::Anchor,
1197        cx: &mut Context<Self>,
1198    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1199        let project_state = self.projects.get(&project.entity_id());
1200
1201        let index_state = project_state.map(|state| {
1202            state
1203                .syntax_index
1204                .read_with(cx, |index, _cx| index.state().clone())
1205        });
1206        let options = self.options.clone();
1207        let snapshot = buffer.read(cx).snapshot();
1208        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1209            return Task::ready(Err(anyhow!("No file path for excerpt")));
1210        };
1211        let worktree_snapshots = project
1212            .read(cx)
1213            .worktrees(cx)
1214            .map(|worktree| worktree.read(cx).snapshot())
1215            .collect::<Vec<_>>();
1216
1217        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1218            let mut path = f.worktree.read(cx).absolutize(&f.path);
1219            if path.pop() { Some(path) } else { None }
1220        });
1221
1222        cx.background_spawn(async move {
1223            let index_state = if let Some(index_state) = index_state {
1224                Some(index_state.lock_owned().await)
1225            } else {
1226                None
1227            };
1228
1229            let cursor_point = position.to_point(&snapshot);
1230
1231            let debug_info = true;
1232            EditPredictionContext::gather_context(
1233                cursor_point,
1234                &snapshot,
1235                parent_abs_path.as_deref(),
1236                match &options.context {
1237                    ContextMode::Llm(_) => {
1238                        // TODO
1239                        panic!("Llm mode not supported in zeta cli yet");
1240                    }
1241                    ContextMode::Syntax(edit_prediction_context_options) => {
1242                        edit_prediction_context_options
1243                    }
1244                },
1245                index_state.as_deref(),
1246            )
1247            .context("Failed to select excerpt")
1248            .map(|context| {
1249                make_syntax_context_cloud_request(
1250                    excerpt_path.into(),
1251                    context,
1252                    // TODO pass everything
1253                    Vec::new(),
1254                    false,
1255                    Vec::new(),
1256                    false,
1257                    None,
1258                    debug_info,
1259                    &worktree_snapshots,
1260                    index_state.as_deref(),
1261                    Some(options.max_prompt_bytes),
1262                    options.prompt_format,
1263                )
1264            })
1265        })
1266    }
1267
1268    pub fn wait_for_initial_indexing(
1269        &mut self,
1270        project: &Entity<Project>,
1271        cx: &mut App,
1272    ) -> Task<Result<()>> {
1273        let zeta_project = self.get_or_init_zeta_project(project, cx);
1274        zeta_project
1275            .syntax_index
1276            .read(cx)
1277            .wait_for_initial_file_indexing(cx)
1278    }
1279}
1280
/// Error returned by `send_api_request` when a response carries a
/// minimum-required-version header that is newer than the running app version.
///
/// `handle_api_response` checks for this error type in order to set
/// `update_required` and show a notification with an "Update Zed" link.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version parsed from the response header.
    minimum_version: SemanticVersion,
}
1288
1289fn make_syntax_context_cloud_request(
1290    excerpt_path: Arc<Path>,
1291    context: EditPredictionContext,
1292    events: Vec<predict_edits_v3::Event>,
1293    can_collect_data: bool,
1294    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1295    diagnostic_groups_truncated: bool,
1296    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1297    debug_info: bool,
1298    worktrees: &Vec<worktree::Snapshot>,
1299    index_state: Option<&SyntaxIndexState>,
1300    prompt_max_bytes: Option<usize>,
1301    prompt_format: PromptFormat,
1302) -> predict_edits_v3::PredictEditsRequest {
1303    let mut signatures = Vec::new();
1304    let mut declaration_to_signature_index = HashMap::default();
1305    let mut referenced_declarations = Vec::new();
1306
1307    for snippet in context.declarations {
1308        let project_entry_id = snippet.declaration.project_entry_id();
1309        let Some(path) = worktrees.iter().find_map(|worktree| {
1310            worktree.entry_for_id(project_entry_id).map(|entry| {
1311                let mut full_path = RelPathBuf::new();
1312                full_path.push(worktree.root_name());
1313                full_path.push(&entry.path);
1314                full_path
1315            })
1316        }) else {
1317            continue;
1318        };
1319
1320        let parent_index = index_state.and_then(|index_state| {
1321            snippet.declaration.parent().and_then(|parent| {
1322                add_signature(
1323                    parent,
1324                    &mut declaration_to_signature_index,
1325                    &mut signatures,
1326                    index_state,
1327                )
1328            })
1329        });
1330
1331        let (text, text_is_truncated) = snippet.declaration.item_text();
1332        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1333            path: path.as_std_path().into(),
1334            text: text.into(),
1335            range: snippet.declaration.item_line_range(),
1336            text_is_truncated,
1337            signature_range: snippet.declaration.signature_range_in_item_text(),
1338            parent_index,
1339            signature_score: snippet.score(DeclarationStyle::Signature),
1340            declaration_score: snippet.score(DeclarationStyle::Declaration),
1341            score_components: snippet.components,
1342        });
1343    }
1344
1345    let excerpt_parent = index_state.and_then(|index_state| {
1346        context
1347            .excerpt
1348            .parent_declarations
1349            .last()
1350            .and_then(|(parent, _)| {
1351                add_signature(
1352                    *parent,
1353                    &mut declaration_to_signature_index,
1354                    &mut signatures,
1355                    index_state,
1356                )
1357            })
1358    });
1359
1360    predict_edits_v3::PredictEditsRequest {
1361        excerpt_path,
1362        excerpt: context.excerpt_text.body,
1363        excerpt_line_range: context.excerpt.line_range,
1364        excerpt_range: context.excerpt.range,
1365        cursor_point: predict_edits_v3::Point {
1366            line: predict_edits_v3::Line(context.cursor_point.row),
1367            column: context.cursor_point.column,
1368        },
1369        referenced_declarations,
1370        included_files: vec![],
1371        signatures,
1372        excerpt_parent,
1373        events,
1374        can_collect_data,
1375        diagnostic_groups,
1376        diagnostic_groups_truncated,
1377        git_info,
1378        debug_info,
1379        prompt_max_bytes,
1380        prompt_format,
1381    }
1382}
1383
1384fn add_signature(
1385    declaration_id: DeclarationId,
1386    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1387    signatures: &mut Vec<Signature>,
1388    index: &SyntaxIndexState,
1389) -> Option<usize> {
1390    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1391        return Some(*signature_index);
1392    }
1393    let Some(parent_declaration) = index.declaration(declaration_id) else {
1394        log::error!("bug: missing parent declaration");
1395        return None;
1396    };
1397    let parent_index = parent_declaration.parent().and_then(|parent| {
1398        add_signature(parent, declaration_to_signature_index, signatures, index)
1399    });
1400    let (text, text_is_truncated) = parent_declaration.signature_text();
1401    let signature_index = signatures.len();
1402    signatures.push(Signature {
1403        text: text.into(),
1404        text_is_truncated,
1405        parent_index,
1406        range: parent_declaration.signature_line_range(),
1407    });
1408    declaration_to_signature_index.insert(declaration_id, signature_index);
1409    Some(signature_index)
1410}
1411
1412#[cfg(test)]
1413mod tests {
1414    use std::{
1415        path::{Path, PathBuf},
1416        sync::Arc,
1417    };
1418
1419    use client::UserStore;
1420    use clock::FakeSystemClock;
1421    use cloud_llm_client::predict_edits_v3::{self, Point};
1422    use edit_prediction_context::Line;
1423    use futures::{
1424        AsyncReadExt, StreamExt,
1425        channel::{mpsc, oneshot},
1426    };
1427    use gpui::{
1428        Entity, TestAppContext,
1429        http_client::{FakeHttpClient, Response},
1430        prelude::*,
1431    };
1432    use indoc::indoc;
1433    use language::{LanguageServerId, OffsetRangeExt as _};
1434    use pretty_assertions::{assert_eq, assert_matches};
1435    use project::{FakeFs, Project};
1436    use serde_json::json;
1437    use settings::SettingsStore;
1438    use util::path;
1439    use uuid::Uuid;
1440
1441    use crate::{BufferEditPrediction, Zeta};
1442
1443    #[gpui::test]
1444    async fn test_current_state(cx: &mut TestAppContext) {
1445        let (zeta, mut req_rx) = init_test(cx);
1446        let fs = FakeFs::new(cx.executor());
1447        fs.insert_tree(
1448            "/root",
1449            json!({
1450                "1.txt": "Hello!\nHow\nBye",
1451                "2.txt": "Hola!\nComo\nAdios"
1452            }),
1453        )
1454        .await;
1455        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1456
1457        zeta.update(cx, |zeta, cx| {
1458            zeta.register_project(&project, cx);
1459        });
1460
1461        let buffer1 = project
1462            .update(cx, |project, cx| {
1463                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1464                project.open_buffer(path, cx)
1465            })
1466            .await
1467            .unwrap();
1468        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1469        let position = snapshot1.anchor_before(language::Point::new(1, 3));
1470
1471        // Prediction for current file
1472
1473        let prediction_task = zeta.update(cx, |zeta, cx| {
1474            zeta.refresh_prediction(&project, &buffer1, position, cx)
1475        });
1476        let (_request, respond_tx) = req_rx.next().await.unwrap();
1477        respond_tx
1478            .send(predict_edits_v3::PredictEditsResponse {
1479                request_id: Uuid::new_v4(),
1480                edits: vec![predict_edits_v3::Edit {
1481                    path: Path::new(path!("root/1.txt")).into(),
1482                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1483                    content: "Hello!\nHow are you?\nBye".into(),
1484                }],
1485                debug_info: None,
1486            })
1487            .unwrap();
1488        prediction_task.await.unwrap();
1489
1490        zeta.read_with(cx, |zeta, cx| {
1491            let prediction = zeta
1492                .current_prediction_for_buffer(&buffer1, &project, cx)
1493                .unwrap();
1494            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1495        });
1496
1497        // Prediction for another file
1498        let prediction_task = zeta.update(cx, |zeta, cx| {
1499            zeta.refresh_prediction(&project, &buffer1, position, cx)
1500        });
1501        let (_request, respond_tx) = req_rx.next().await.unwrap();
1502        respond_tx
1503            .send(predict_edits_v3::PredictEditsResponse {
1504                request_id: Uuid::new_v4(),
1505                edits: vec![predict_edits_v3::Edit {
1506                    path: Path::new(path!("root/2.txt")).into(),
1507                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1508                    content: "Hola!\nComo estas?\nAdios".into(),
1509                }],
1510                debug_info: None,
1511            })
1512            .unwrap();
1513        prediction_task.await.unwrap();
1514        zeta.read_with(cx, |zeta, cx| {
1515            let prediction = zeta
1516                .current_prediction_for_buffer(&buffer1, &project, cx)
1517                .unwrap();
1518            assert_matches!(
1519                prediction,
1520                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1521            );
1522        });
1523
1524        let buffer2 = project
1525            .update(cx, |project, cx| {
1526                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1527                project.open_buffer(path, cx)
1528            })
1529            .await
1530            .unwrap();
1531
1532        zeta.read_with(cx, |zeta, cx| {
1533            let prediction = zeta
1534                .current_prediction_for_buffer(&buffer2, &project, cx)
1535                .unwrap();
1536            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1537        });
1538    }
1539
1540    #[gpui::test]
1541    async fn test_simple_request(cx: &mut TestAppContext) {
1542        let (zeta, mut req_rx) = init_test(cx);
1543        let fs = FakeFs::new(cx.executor());
1544        fs.insert_tree(
1545            "/root",
1546            json!({
1547                "foo.md":  "Hello!\nHow\nBye"
1548            }),
1549        )
1550        .await;
1551        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1552
1553        let buffer = project
1554            .update(cx, |project, cx| {
1555                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1556                project.open_buffer(path, cx)
1557            })
1558            .await
1559            .unwrap();
1560        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1561        let position = snapshot.anchor_before(language::Point::new(1, 3));
1562
1563        let prediction_task = zeta.update(cx, |zeta, cx| {
1564            zeta.request_prediction(&project, &buffer, position, cx)
1565        });
1566
1567        let (request, respond_tx) = req_rx.next().await.unwrap();
1568        assert_eq!(
1569            request.excerpt_path.as_ref(),
1570            Path::new(path!("root/foo.md"))
1571        );
1572        assert_eq!(
1573            request.cursor_point,
1574            Point {
1575                line: Line(1),
1576                column: 3
1577            }
1578        );
1579
1580        respond_tx
1581            .send(predict_edits_v3::PredictEditsResponse {
1582                request_id: Uuid::new_v4(),
1583                edits: vec![predict_edits_v3::Edit {
1584                    path: Path::new(path!("root/foo.md")).into(),
1585                    range: Line(0)..Line(snapshot.max_point().row + 1),
1586                    content: "Hello!\nHow are you?\nBye".into(),
1587                }],
1588                debug_info: None,
1589            })
1590            .unwrap();
1591
1592        let prediction = prediction_task.await.unwrap().unwrap();
1593
1594        assert_eq!(prediction.edits.len(), 1);
1595        assert_eq!(
1596            prediction.edits[0].0.to_point(&snapshot).start,
1597            language::Point::new(1, 3)
1598        );
1599        assert_eq!(prediction.edits[0].1, " are you?");
1600    }
1601
1602    #[gpui::test]
1603    async fn test_request_events(cx: &mut TestAppContext) {
1604        let (zeta, mut req_rx) = init_test(cx);
1605        let fs = FakeFs::new(cx.executor());
1606        fs.insert_tree(
1607            "/root",
1608            json!({
1609                "foo.md": "Hello!\n\nBye"
1610            }),
1611        )
1612        .await;
1613        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1614
1615        let buffer = project
1616            .update(cx, |project, cx| {
1617                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1618                project.open_buffer(path, cx)
1619            })
1620            .await
1621            .unwrap();
1622
1623        zeta.update(cx, |zeta, cx| {
1624            zeta.register_buffer(&buffer, &project, cx);
1625        });
1626
1627        buffer.update(cx, |buffer, cx| {
1628            buffer.edit(vec![(7..7, "How")], None, cx);
1629        });
1630
1631        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1632        let position = snapshot.anchor_before(language::Point::new(1, 3));
1633
1634        let prediction_task = zeta.update(cx, |zeta, cx| {
1635            zeta.request_prediction(&project, &buffer, position, cx)
1636        });
1637
1638        let (request, respond_tx) = req_rx.next().await.unwrap();
1639
1640        assert_eq!(request.events.len(), 1);
1641        assert_eq!(
1642            request.events[0],
1643            predict_edits_v3::Event::BufferChange {
1644                path: Some(PathBuf::from(path!("root/foo.md"))),
1645                old_path: None,
1646                diff: indoc! {"
1647                        @@ -1,3 +1,3 @@
1648                         Hello!
1649                        -
1650                        +How
1651                         Bye
1652                    "}
1653                .to_string(),
1654                predicted: false
1655            }
1656        );
1657
1658        respond_tx
1659            .send(predict_edits_v3::PredictEditsResponse {
1660                request_id: Uuid::new_v4(),
1661                edits: vec![predict_edits_v3::Edit {
1662                    path: Path::new(path!("root/foo.md")).into(),
1663                    range: Line(0)..Line(snapshot.max_point().row + 1),
1664                    content: "Hello!\nHow are you?\nBye".into(),
1665                }],
1666                debug_info: None,
1667            })
1668            .unwrap();
1669
1670        let prediction = prediction_task.await.unwrap().unwrap();
1671
1672        assert_eq!(prediction.edits.len(), 1);
1673        assert_eq!(
1674            prediction.edits[0].0.to_point(&snapshot).start,
1675            language::Point::new(1, 3)
1676        );
1677        assert_eq!(prediction.edits[0].1, " are you?");
1678    }
1679
1680    #[gpui::test]
1681    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1682        let (zeta, mut req_rx) = init_test(cx);
1683        let fs = FakeFs::new(cx.executor());
1684        fs.insert_tree(
1685            "/root",
1686            json!({
1687                "foo.md": "Hello!\nBye"
1688            }),
1689        )
1690        .await;
1691        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1692
1693        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1694        let diagnostic = lsp::Diagnostic {
1695            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1696            severity: Some(lsp::DiagnosticSeverity::ERROR),
1697            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1698            ..Default::default()
1699        };
1700
1701        project.update(cx, |project, cx| {
1702            project.lsp_store().update(cx, |lsp_store, cx| {
1703                // Create some diagnostics
1704                lsp_store
1705                    .update_diagnostics(
1706                        LanguageServerId(0),
1707                        lsp::PublishDiagnosticsParams {
1708                            uri: path_to_buffer_uri.clone(),
1709                            diagnostics: vec![diagnostic],
1710                            version: None,
1711                        },
1712                        None,
1713                        language::DiagnosticSourceKind::Pushed,
1714                        &[],
1715                        cx,
1716                    )
1717                    .unwrap();
1718            });
1719        });
1720
1721        let buffer = project
1722            .update(cx, |project, cx| {
1723                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1724                project.open_buffer(path, cx)
1725            })
1726            .await
1727            .unwrap();
1728
1729        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1730        let position = snapshot.anchor_before(language::Point::new(0, 0));
1731
1732        let _prediction_task = zeta.update(cx, |zeta, cx| {
1733            zeta.request_prediction(&project, &buffer, position, cx)
1734        });
1735
1736        let (request, _respond_tx) = req_rx.next().await.unwrap();
1737
1738        assert_eq!(request.diagnostic_groups.len(), 1);
1739        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1740            .unwrap();
1741        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1742        assert_eq!(
1743            value,
1744            json!({
1745                "entries": [{
1746                    "range": {
1747                        "start": 8,
1748                        "end": 10
1749                    },
1750                    "diagnostic": {
1751                        "source": null,
1752                        "code": null,
1753                        "code_description": null,
1754                        "severity": 1,
1755                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1756                        "markdown": null,
1757                        "group_id": 0,
1758                        "is_primary": true,
1759                        "is_disk_based": false,
1760                        "is_unnecessary": false,
1761                        "source_kind": "Pushed",
1762                        "data": null,
1763                        "underline": true
1764                    }
1765                }],
1766                "primary_ix": 0
1767            })
1768        );
1769    }
1770
1771    fn init_test(
1772        cx: &mut TestAppContext,
1773    ) -> (
1774        Entity<Zeta>,
1775        mpsc::UnboundedReceiver<(
1776            predict_edits_v3::PredictEditsRequest,
1777            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1778        )>,
1779    ) {
1780        cx.update(move |cx| {
1781            let settings_store = SettingsStore::test(cx);
1782            cx.set_global(settings_store);
1783            language::init(cx);
1784            Project::init_settings(cx);
1785
1786            let (req_tx, req_rx) = mpsc::unbounded();
1787
1788            let http_client = FakeHttpClient::create({
1789                move |req| {
1790                    let uri = req.uri().path().to_string();
1791                    let mut body = req.into_body();
1792                    let req_tx = req_tx.clone();
1793                    async move {
1794                        let resp = match uri.as_str() {
1795                            "/client/llm_tokens" => serde_json::to_string(&json!({
1796                                "token": "test"
1797                            }))
1798                            .unwrap(),
1799                            "/predict_edits/v3" => {
1800                                let mut buf = Vec::new();
1801                                body.read_to_end(&mut buf).await.ok();
1802                                let req = serde_json::from_slice(&buf).unwrap();
1803
1804                                let (res_tx, res_rx) = oneshot::channel();
1805                                req_tx.unbounded_send((req, res_tx)).unwrap();
1806                                serde_json::to_string(&res_rx.await?).unwrap()
1807                            }
1808                            _ => {
1809                                panic!("Unexpected path: {}", uri)
1810                            }
1811                        };
1812
1813                        Ok(Response::builder().body(resp.into()).unwrap())
1814                    }
1815                }
1816            });
1817
1818            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1819            client.cloud_client().set_credentials(1, "test".into());
1820
1821            language_model::init(client.clone(), cx);
1822
1823            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1824            let zeta = Zeta::global(&client, &user_store, cx);
1825
1826            (zeta, req_rx)
1827        })
1828    }
1829}