1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
  10use collections::HashMap;
  11use edit_prediction_context::{
  12    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  13    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  14    SyntaxIndex, SyntaxIndexState,
  15};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::AsyncReadExt as _;
  18use futures::channel::{mpsc, oneshot};
  19use gpui::http_client::{AsyncBody, Method};
  20use gpui::{
  21    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  22    http_client, prelude::*,
  23};
  24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language::{BufferSnapshot, OffsetRangeExt};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use serde::de::DeserializeOwned;
  30use std::collections::{VecDeque, hash_map};
  31use std::ops::Range;
  32use std::path::Path;
  33use std::str::FromStr as _;
  34use std::sync::Arc;
  35use std::time::{Duration, Instant};
  36use thiserror::Error;
  37use util::ResultExt as _;
  38use util::rel_path::RelPathBuf;
  39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  40
  41mod merge_excerpts;
  42mod prediction;
  43mod provider;
  44mod related_excerpts;
  45
  46use crate::merge_excerpts::merge_excerpts;
  47use crate::prediction::EditPrediction;
  48use crate::related_excerpts::find_related_excerpts;
  49pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
  50pub use provider::ZetaEditPredictionProvider;
  51
  52const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  53
  54/// Maximum number of events to track.
  55const MAX_EVENT_COUNT: usize = 16;
  56
  57pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  58    max_bytes: 512,
  59    min_bytes: 128,
  60    target_before_cursor_over_total_bytes: 0.5,
  61};
  62
  63pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
  64
  65pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
  66    excerpt: DEFAULT_EXCERPT_OPTIONS,
  67};
  68
  69pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
  70    EditPredictionContextOptions {
  71        use_imports: true,
  72        max_retrieved_declarations: 0,
  73        excerpt: DEFAULT_EXCERPT_OPTIONS,
  74        score: EditPredictionScoreOptions {
  75            omit_excerpt_overlaps: true,
  76        },
  77    };
  78
  79pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  80    context: DEFAULT_CONTEXT_OPTIONS,
  81    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  82    max_diagnostic_bytes: 2048,
  83    prompt_format: PromptFormat::DEFAULT,
  84    file_indexing_parallelism: 1,
  85};
  86
  87pub struct Zeta2FeatureFlag;
  88
  89impl FeatureFlag for Zeta2FeatureFlag {
  90    const NAME: &'static str = "zeta2";
  91
  92    fn enabled_for_staff() -> bool {
  93        false
  94    }
  95}
  96
/// Newtype that stores the app-wide [`Zeta`] entity as a gpui global.
///
/// Installed by [`Zeta::global`] and read back by [`Zeta::try_global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 101
/// Edit-prediction engine state, shared app-wide through [`ZetaGlobal`].
pub struct Zeta {
    /// Client used for HTTP requests to the prediction service.
    client: Arc<Client>,
    /// User store; queried for edit-prediction usage.
    user_store: Entity<UserStore>,
    /// LLM API token sent with requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Active configuration (starts as `DEFAULT_OPTIONS`).
    options: ZetaOptions,
    // NOTE(review): not read in the visible part of this file; presumably set when the
    // server reports that a newer client version is required — confirm.
    update_required: bool,
    /// Debug event sender; `Some` only after [`Zeta::debug_info`] has been called.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
}
 112
/// Tunable configuration for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered.
    pub context: ContextMode,
    /// Prompt size limit, forwarded to requests as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte limit passed to the nearby-diagnostics gathering step.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
 121
/// Strategy used to gather context around the cursor.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Cursor excerpt plus the project's previously retrieved context ranges.
    Llm(LlmContextOptions),
    /// Context gathered through `EditPredictionContext::gather_context` and the syntax index.
    Syntax(EditPredictionContextOptions),
}
 127
 128impl ContextMode {
 129    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 130        match self {
 131            ContextMode::Llm(options) => &options.excerpt,
 132            ContextMode::Syntax(options) => &options.excerpt,
 133        }
 134    }
 135}
 136
/// Debug events delivered to the receiver returned by [`Zeta::debug_info`].
pub enum ZetaDebugInfo {
    // NOTE(review): the retrieval/search variants are not emitted in the visible part of
    // this file; their names suggest they mark context-retrieval progress — confirm.
    ContextRetrievalStarted(ZetaContextRetrievalDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    /// A prediction request was built; carries the request and a channel for its response.
    EditPredicted(ZetaEditPredictionDebugInfo),
}
 144
/// Payload for context-retrieval debug events.
pub struct ZetaContextRetrievalDebugInfo {
    /// Project the retrieval ran for.
    pub project: Entity<Project>,
    // NOTE(review): presumably the time the event was emitted — confirm with the emitter.
    pub timestamp: Instant,
}
 149
/// Debug payload emitted for each prediction request while a debug receiver is installed.
pub struct ZetaEditPredictionDebugInfo {
    /// The request as built locally.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context (diagnostics, excerpts) and building `request`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from `request`, or the rendering error message.
    pub local_prompt: Result<String, String>,
    /// Receives the server response, or an error message (e.g. when the request is skipped).
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}
 158
/// Payload for the `SearchQueriesGenerated` debug event.
pub struct ZetaSearchQueryDebugInfo {
    /// Project the queries were generated for.
    pub project: Entity<Project>,
    // NOTE(review): presumably the time the queries were generated — confirm with the emitter.
    pub timestamp: Instant,
    /// The generated search queries.
    pub queries: Vec<SearchToolQuery>,
}
 164
/// Re-export of the wire-format debug info type from `predict_edits_v3`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 166
/// Per-project prediction state owned by [`Zeta`].
struct ZetaProject {
    /// Syntax index for the project, created when the project is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; converted into request events.
    events: VecDeque<Event>,
    /// Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently held for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Retrieved context ranges per buffer; sent as included files in LLM-mode requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    // NOTE(review): the three fields below are not used in the visible part of this file;
    // presumably they track context refreshes (task, debounce task, last refresh time) — confirm.
    refresh_context_task: Option<Task<Option<()>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}
 177
/// The prediction currently held for a project, plus the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the request came from; may differ from the buffer the
    /// prediction edits (the "jump" case).
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 183
 184impl CurrentEditPrediction {
 185    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 186        let Some(new_edits) = self
 187            .prediction
 188            .interpolate(&self.prediction.buffer.read(cx))
 189        else {
 190            return false;
 191        };
 192
 193        if self.prediction.buffer != old_prediction.prediction.buffer {
 194            return true;
 195        }
 196
 197        let Some(old_edits) = old_prediction
 198            .prediction
 199            .interpolate(&old_prediction.prediction.buffer.read(cx))
 200        else {
 201            return true;
 202        };
 203
 204        // This reduces the occurrence of UI thrash from replacing edits
 205        //
 206        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 207        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 208            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 209            && old_edits.len() == 1
 210            && new_edits.len() == 1
 211        {
 212            let (old_range, old_text) = &old_edits[0];
 213            let (new_range, new_text) = &new_edits[0];
 214            new_range == old_range && new_text.starts_with(old_text)
 215        } else {
 216            true
 217        }
 218    }
 219}
 220
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}
 227
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; its version is compared against the
    /// buffer's latest snapshot to detect new edits.
    snapshot: BufferSnapshot,
    /// Edit and release subscriptions, kept alive while the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
}
 232
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`.
    ///
    /// Consecutive changes may be coalesced into one event, in which case
    /// `timestamp` is the time of the most recent change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 241
 242impl Event {
 243    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 244        match self {
 245            Event::BufferChange {
 246                old_snapshot,
 247                new_snapshot,
 248                ..
 249            } => {
 250                let path = new_snapshot.file().map(|f| f.full_path(cx));
 251
 252                let old_path = old_snapshot.file().and_then(|f| {
 253                    let old_path = f.full_path(cx);
 254                    if Some(&old_path) != path.as_ref() {
 255                        Some(old_path)
 256                    } else {
 257                        None
 258                    }
 259                });
 260
 261                // TODO [zeta2] move to bg?
 262                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 263
 264                if path == old_path && diff.is_empty() {
 265                    None
 266                } else {
 267                    Some(predict_edits_v3::Event::BufferChange {
 268                        old_path,
 269                        path,
 270                        diff,
 271                        //todo: Actually detect if this edit was predicted or not
 272                        predicted: false,
 273                    })
 274                }
 275            }
 276        }
 277    }
 278}
 279
 280impl Zeta {
 281    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 282        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 283    }
 284
 285    pub fn global(
 286        client: &Arc<Client>,
 287        user_store: &Entity<UserStore>,
 288        cx: &mut App,
 289    ) -> Entity<Self> {
 290        cx.try_global::<ZetaGlobal>()
 291            .map(|global| global.0.clone())
 292            .unwrap_or_else(|| {
 293                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 294                cx.set_global(ZetaGlobal(zeta.clone()));
 295                zeta
 296            })
 297    }
 298
 299    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 300        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 301
 302        Self {
 303            projects: HashMap::default(),
 304            client,
 305            user_store,
 306            options: DEFAULT_OPTIONS,
 307            llm_token: LlmApiToken::default(),
 308            _llm_token_subscription: cx.subscribe(
 309                &refresh_llm_token_listener,
 310                |this, _listener, _event, cx| {
 311                    let client = this.client.clone();
 312                    let llm_token = this.llm_token.clone();
 313                    cx.spawn(async move |_this, _cx| {
 314                        llm_token.refresh(&client).await?;
 315                        anyhow::Ok(())
 316                    })
 317                    .detach_and_log_err(cx);
 318                },
 319            ),
 320            update_required: false,
 321            debug_tx: None,
 322        }
 323    }
 324
 325    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 326        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 327        self.debug_tx = Some(debug_watch_tx);
 328        debug_watch_rx
 329    }
 330
 331    pub fn options(&self) -> &ZetaOptions {
 332        &self.options
 333    }
 334
 335    pub fn set_options(&mut self, options: ZetaOptions) {
 336        self.options = options;
 337    }
 338
 339    pub fn clear_history(&mut self) {
 340        for zeta_project in self.projects.values_mut() {
 341            zeta_project.events.clear();
 342        }
 343    }
 344
 345    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 346        self.projects
 347            .get(&project.entity_id())
 348            .map(|project| project.events.iter())
 349            .into_iter()
 350            .flatten()
 351    }
 352
 353    pub fn context_for_project(
 354        &self,
 355        project: &Entity<Project>,
 356    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 357        self.projects
 358            .get(&project.entity_id())
 359            .and_then(|project| {
 360                Some(
 361                    project
 362                        .context
 363                        .as_ref()?
 364                        .iter()
 365                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 366                )
 367            })
 368            .into_iter()
 369            .flatten()
 370    }
 371
 372    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 373        self.user_store.read(cx).edit_prediction_usage()
 374    }
 375
 376    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 377        self.get_or_init_zeta_project(project, cx);
 378    }
 379
 380    pub fn register_buffer(
 381        &mut self,
 382        buffer: &Entity<Buffer>,
 383        project: &Entity<Project>,
 384        cx: &mut Context<Self>,
 385    ) {
 386        let zeta_project = self.get_or_init_zeta_project(project, cx);
 387        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 388    }
 389
 390    fn get_or_init_zeta_project(
 391        &mut self,
 392        project: &Entity<Project>,
 393        cx: &mut App,
 394    ) -> &mut ZetaProject {
 395        self.projects
 396            .entry(project.entity_id())
 397            .or_insert_with(|| ZetaProject {
 398                syntax_index: cx.new(|cx| {
 399                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 400                }),
 401                events: VecDeque::new(),
 402                registered_buffers: HashMap::default(),
 403                current_prediction: None,
 404                context: None,
 405                refresh_context_task: None,
 406                refresh_context_debounce_task: None,
 407                refresh_context_timestamp: None,
 408            })
 409    }
 410
 411    fn register_buffer_impl<'a>(
 412        zeta_project: &'a mut ZetaProject,
 413        buffer: &Entity<Buffer>,
 414        project: &Entity<Project>,
 415        cx: &mut Context<Self>,
 416    ) -> &'a mut RegisteredBuffer {
 417        let buffer_id = buffer.entity_id();
 418        match zeta_project.registered_buffers.entry(buffer_id) {
 419            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 420            hash_map::Entry::Vacant(entry) => {
 421                let snapshot = buffer.read(cx).snapshot();
 422                let project_entity_id = project.entity_id();
 423                entry.insert(RegisteredBuffer {
 424                    snapshot,
 425                    _subscriptions: [
 426                        cx.subscribe(buffer, {
 427                            let project = project.downgrade();
 428                            move |this, buffer, event, cx| {
 429                                if let language::BufferEvent::Edited = event
 430                                    && let Some(project) = project.upgrade()
 431                                {
 432                                    this.report_changes_for_buffer(&buffer, &project, cx);
 433                                }
 434                            }
 435                        }),
 436                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 437                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 438                            else {
 439                                return;
 440                            };
 441                            zeta_project.registered_buffers.remove(&buffer_id);
 442                        }),
 443                    ],
 444                })
 445            }
 446        }
 447    }
 448
 449    fn report_changes_for_buffer(
 450        &mut self,
 451        buffer: &Entity<Buffer>,
 452        project: &Entity<Project>,
 453        cx: &mut Context<Self>,
 454    ) -> BufferSnapshot {
 455        let zeta_project = self.get_or_init_zeta_project(project, cx);
 456        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 457
 458        let new_snapshot = buffer.read(cx).snapshot();
 459        if new_snapshot.version != registered_buffer.snapshot.version {
 460            let old_snapshot =
 461                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 462            Self::push_event(
 463                zeta_project,
 464                Event::BufferChange {
 465                    old_snapshot,
 466                    new_snapshot: new_snapshot.clone(),
 467                    timestamp: Instant::now(),
 468                },
 469            );
 470        }
 471
 472        new_snapshot
 473    }
 474
 475    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 476        let events = &mut zeta_project.events;
 477
 478        if let Some(Event::BufferChange {
 479            new_snapshot: last_new_snapshot,
 480            timestamp: last_timestamp,
 481            ..
 482        }) = events.back_mut()
 483        {
 484            // Coalesce edits for the same buffer when they happen one after the other.
 485            let Event::BufferChange {
 486                old_snapshot,
 487                new_snapshot,
 488                timestamp,
 489            } = &event;
 490
 491            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 492                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 493                && old_snapshot.version == last_new_snapshot.version
 494            {
 495                *last_new_snapshot = new_snapshot.clone();
 496                *last_timestamp = *timestamp;
 497                return;
 498            }
 499        }
 500
 501        if events.len() >= MAX_EVENT_COUNT {
 502            // These are halved instead of popping to improve prompt caching.
 503            events.drain(..MAX_EVENT_COUNT / 2);
 504        }
 505
 506        events.push_back(event);
 507    }
 508
 509    fn current_prediction_for_buffer(
 510        &self,
 511        buffer: &Entity<Buffer>,
 512        project: &Entity<Project>,
 513        cx: &App,
 514    ) -> Option<BufferEditPrediction<'_>> {
 515        let project_state = self.projects.get(&project.entity_id())?;
 516
 517        let CurrentEditPrediction {
 518            requested_by_buffer_id,
 519            prediction,
 520        } = project_state.current_prediction.as_ref()?;
 521
 522        if prediction.targets_buffer(buffer.read(cx), cx) {
 523            Some(BufferEditPrediction::Local { prediction })
 524        } else if *requested_by_buffer_id == buffer.entity_id() {
 525            Some(BufferEditPrediction::Jump { prediction })
 526        } else {
 527            None
 528        }
 529    }
 530
 531    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 532        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 533            return;
 534        };
 535
 536        let Some(prediction) = project_state.current_prediction.take() else {
 537            return;
 538        };
 539        let request_id = prediction.prediction.id.into();
 540
 541        let client = self.client.clone();
 542        let llm_token = self.llm_token.clone();
 543        let app_version = AppVersion::global(cx);
 544        cx.spawn(async move |this, cx| {
 545            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 546                http_client::Url::parse(&predict_edits_url)?
 547            } else {
 548                client
 549                    .http_client()
 550                    .build_zed_llm_url("/predict_edits/accept", &[])?
 551            };
 552
 553            let response = cx
 554                .background_spawn(Self::send_api_request::<()>(
 555                    move |builder| {
 556                        let req = builder.uri(url.as_ref()).body(
 557                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 558                        );
 559                        Ok(req?)
 560                    },
 561                    client,
 562                    llm_token,
 563                    app_version,
 564                ))
 565                .await;
 566
 567            Self::handle_api_response(&this, response, cx)?;
 568            anyhow::Ok(())
 569        })
 570        .detach_and_log_err(cx);
 571    }
 572
 573    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 574        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 575            project_state.current_prediction.take();
 576        };
 577    }
 578
 579    pub fn refresh_prediction(
 580        &mut self,
 581        project: &Entity<Project>,
 582        buffer: &Entity<Buffer>,
 583        position: language::Anchor,
 584        cx: &mut Context<Self>,
 585    ) -> Task<Result<()>> {
 586        let request_task = self.request_prediction(project, buffer, position, cx);
 587        let buffer = buffer.clone();
 588        let project = project.clone();
 589
 590        cx.spawn(async move |this, cx| {
 591            if let Some(prediction) = request_task.await? {
 592                this.update(cx, |this, cx| {
 593                    let project_state = this
 594                        .projects
 595                        .get_mut(&project.entity_id())
 596                        .context("Project not found")?;
 597
 598                    let new_prediction = CurrentEditPrediction {
 599                        requested_by_buffer_id: buffer.entity_id(),
 600                        prediction: prediction,
 601                    };
 602
 603                    if project_state
 604                        .current_prediction
 605                        .as_ref()
 606                        .is_none_or(|old_prediction| {
 607                            new_prediction.should_replace_prediction(&old_prediction, cx)
 608                        })
 609                    {
 610                        project_state.current_prediction = Some(new_prediction);
 611                    }
 612                    anyhow::Ok(())
 613                })??;
 614            }
 615            Ok(())
 616        })
 617    }
 618
 619    fn request_prediction(
 620        &mut self,
 621        project: &Entity<Project>,
 622        buffer: &Entity<Buffer>,
 623        position: language::Anchor,
 624        cx: &mut Context<Self>,
 625    ) -> Task<Result<Option<EditPrediction>>> {
 626        let project_state = self.projects.get(&project.entity_id());
 627
 628        let index_state = project_state.map(|state| {
 629            state
 630                .syntax_index
 631                .read_with(cx, |index, _cx| index.state().clone())
 632        });
 633        let options = self.options.clone();
 634        let snapshot = buffer.read(cx).snapshot();
 635        let Some(excerpt_path) = snapshot
 636            .file()
 637            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 638        else {
 639            return Task::ready(Err(anyhow!("No file path for excerpt")));
 640        };
 641        let client = self.client.clone();
 642        let llm_token = self.llm_token.clone();
 643        let app_version = AppVersion::global(cx);
 644        let worktree_snapshots = project
 645            .read(cx)
 646            .worktrees(cx)
 647            .map(|worktree| worktree.read(cx).snapshot())
 648            .collect::<Vec<_>>();
 649        let debug_tx = self.debug_tx.clone();
 650
 651        let events = project_state
 652            .map(|state| {
 653                state
 654                    .events
 655                    .iter()
 656                    .filter_map(|event| event.to_request_event(cx))
 657                    .collect::<Vec<_>>()
 658            })
 659            .unwrap_or_default();
 660
 661        let diagnostics = snapshot.diagnostic_sets().clone();
 662
 663        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 664            let mut path = f.worktree.read(cx).absolutize(&f.path);
 665            if path.pop() { Some(path) } else { None }
 666        });
 667
 668        // TODO data collection
 669        let can_collect_data = cx.is_staff();
 670
 671        let mut included_files = project_state
 672            .and_then(|project_state| project_state.context.as_ref())
 673            .unwrap_or(&HashMap::default())
 674            .iter()
 675            .filter_map(|(buffer, ranges)| {
 676                let buffer = buffer.read(cx);
 677                Some((
 678                    buffer.snapshot(),
 679                    buffer.file()?.full_path(cx).into(),
 680                    ranges.clone(),
 681                ))
 682            })
 683            .collect::<Vec<_>>();
 684
 685        let request_task = cx.background_spawn({
 686            let snapshot = snapshot.clone();
 687            let buffer = buffer.clone();
 688            async move {
 689                let index_state = if let Some(index_state) = index_state {
 690                    Some(index_state.lock_owned().await)
 691                } else {
 692                    None
 693                };
 694
 695                let cursor_offset = position.to_offset(&snapshot);
 696                let cursor_point = cursor_offset.to_point(&snapshot);
 697
 698                let before_retrieval = chrono::Utc::now();
 699
 700                let (diagnostic_groups, diagnostic_groups_truncated) =
 701                    Self::gather_nearby_diagnostics(
 702                        cursor_offset,
 703                        &diagnostics,
 704                        &snapshot,
 705                        options.max_diagnostic_bytes,
 706                    );
 707
 708                let request = match options.context {
 709                    ContextMode::Llm(context_options) => {
 710                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 711                            cursor_point,
 712                            &snapshot,
 713                            &context_options.excerpt,
 714                            index_state.as_deref(),
 715                        ) else {
 716                            return Ok((None, None));
 717                        };
 718
 719                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
 720                            ..snapshot.anchor_before(excerpt.range.end);
 721
 722                        if let Some(buffer_ix) = included_files
 723                            .iter()
 724                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
 725                        {
 726                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
 727                            let range_ix = ranges
 728                                .binary_search_by(|probe| {
 729                                    probe
 730                                        .start
 731                                        .cmp(&excerpt_anchor_range.start, buffer)
 732                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 733                                })
 734                                .unwrap_or_else(|ix| ix);
 735
 736                            ranges.insert(range_ix, excerpt_anchor_range);
 737                            let last_ix = included_files.len() - 1;
 738                            included_files.swap(buffer_ix, last_ix);
 739                        } else {
 740                            included_files.push((
 741                                snapshot,
 742                                excerpt_path.clone(),
 743                                vec![excerpt_anchor_range],
 744                            ));
 745                        }
 746
 747                        let included_files = included_files
 748                            .into_iter()
 749                            .map(|(buffer, path, ranges)| {
 750                                let excerpts = merge_excerpts(
 751                                    &buffer,
 752                                    ranges.iter().map(|range| {
 753                                        let point_range = range.to_point(&buffer);
 754                                        Line(point_range.start.row)..Line(point_range.end.row)
 755                                    }),
 756                                );
 757                                predict_edits_v3::IncludedFile {
 758                                    path,
 759                                    max_row: Line(buffer.max_point().row),
 760                                    excerpts,
 761                                }
 762                            })
 763                            .collect::<Vec<_>>();
 764
 765                        predict_edits_v3::PredictEditsRequest {
 766                            excerpt_path,
 767                            excerpt: String::new(),
 768                            excerpt_line_range: Line(0)..Line(0),
 769                            excerpt_range: 0..0,
 770                            cursor_point: predict_edits_v3::Point {
 771                                line: predict_edits_v3::Line(cursor_point.row),
 772                                column: cursor_point.column,
 773                            },
 774                            included_files,
 775                            referenced_declarations: vec![],
 776                            events,
 777                            can_collect_data,
 778                            diagnostic_groups,
 779                            diagnostic_groups_truncated,
 780                            debug_info: debug_tx.is_some(),
 781                            prompt_max_bytes: Some(options.max_prompt_bytes),
 782                            prompt_format: options.prompt_format,
 783                            // TODO [zeta2]
 784                            signatures: vec![],
 785                            excerpt_parent: None,
 786                            git_info: None,
 787                        }
 788                    }
 789                    ContextMode::Syntax(context_options) => {
 790                        let Some(context) = EditPredictionContext::gather_context(
 791                            cursor_point,
 792                            &snapshot,
 793                            parent_abs_path.as_deref(),
 794                            &context_options,
 795                            index_state.as_deref(),
 796                        ) else {
 797                            return Ok((None, None));
 798                        };
 799
 800                        make_syntax_context_cloud_request(
 801                            excerpt_path,
 802                            context,
 803                            events,
 804                            can_collect_data,
 805                            diagnostic_groups,
 806                            diagnostic_groups_truncated,
 807                            None,
 808                            debug_tx.is_some(),
 809                            &worktree_snapshots,
 810                            index_state.as_deref(),
 811                            Some(options.max_prompt_bytes),
 812                            options.prompt_format,
 813                        )
 814                    }
 815                };
 816
 817                let retrieval_time = chrono::Utc::now() - before_retrieval;
 818
 819                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 820                    let (response_tx, response_rx) = oneshot::channel();
 821
 822                    let local_prompt = build_prompt(&request)
 823                        .map(|(prompt, _)| prompt)
 824                        .map_err(|err| err.to_string());
 825
 826                    debug_tx
 827                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
 828                            request: request.clone(),
 829                            retrieval_time,
 830                            buffer: buffer.downgrade(),
 831                            local_prompt,
 832                            position,
 833                            response_rx,
 834                        }))
 835                        .ok();
 836                    Some(response_tx)
 837                } else {
 838                    None
 839                };
 840
 841                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 842                    if let Some(debug_response_tx) = debug_response_tx {
 843                        debug_response_tx
 844                            .send(Err("Request skipped".to_string()))
 845                            .ok();
 846                    }
 847                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 848                }
 849
 850                let response =
 851                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 852
 853                if let Some(debug_response_tx) = debug_response_tx {
 854                    debug_response_tx
 855                        .send(
 856                            response
 857                                .as_ref()
 858                                .map_err(|err| err.to_string())
 859                                .map(|response| response.0.clone()),
 860                        )
 861                        .ok();
 862                }
 863
 864                response.map(|(res, usage)| (Some(res), usage))
 865            }
 866        });
 867
 868        let buffer = buffer.clone();
 869
 870        cx.spawn({
 871            let project = project.clone();
 872            async move |this, cx| {
 873                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 874                else {
 875                    return Ok(None);
 876                };
 877
 878                // TODO telemetry: duration, etc
 879                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 880            }
 881        })
 882    }
 883
 884    async fn send_prediction_request(
 885        client: Arc<Client>,
 886        llm_token: LlmApiToken,
 887        app_version: SemanticVersion,
 888        request: predict_edits_v3::PredictEditsRequest,
 889    ) -> Result<(
 890        predict_edits_v3::PredictEditsResponse,
 891        Option<EditPredictionUsage>,
 892    )> {
 893        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 894            http_client::Url::parse(&predict_edits_url)?
 895        } else {
 896            client
 897                .http_client()
 898                .build_zed_llm_url("/predict_edits/v3", &[])?
 899        };
 900
 901        Self::send_api_request(
 902            |builder| {
 903                let req = builder
 904                    .uri(url.as_ref())
 905                    .body(serde_json::to_string(&request)?.into());
 906                Ok(req?)
 907            },
 908            client,
 909            llm_token,
 910            app_version,
 911        )
 912        .await
 913    }
 914
 915    fn handle_api_response<T>(
 916        this: &WeakEntity<Self>,
 917        response: Result<(T, Option<EditPredictionUsage>)>,
 918        cx: &mut gpui::AsyncApp,
 919    ) -> Result<T> {
 920        match response {
 921            Ok((data, usage)) => {
 922                if let Some(usage) = usage {
 923                    this.update(cx, |this, cx| {
 924                        this.user_store.update(cx, |user_store, cx| {
 925                            user_store.update_edit_prediction_usage(usage, cx);
 926                        });
 927                    })
 928                    .ok();
 929                }
 930                Ok(data)
 931            }
 932            Err(err) => {
 933                if err.is::<ZedUpdateRequiredError>() {
 934                    cx.update(|cx| {
 935                        this.update(cx, |this, _cx| {
 936                            this.update_required = true;
 937                        })
 938                        .ok();
 939
 940                        let error_message: SharedString = err.to_string().into();
 941                        show_app_notification(
 942                            NotificationId::unique::<ZedUpdateRequiredError>(),
 943                            cx,
 944                            move |cx| {
 945                                cx.new(|cx| {
 946                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 947                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 948                                })
 949                            },
 950                        );
 951                    })
 952                    .ok();
 953                }
 954                Err(err)
 955            }
 956        }
 957    }
 958
 959    async fn send_api_request<Res>(
 960        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 961        client: Arc<Client>,
 962        llm_token: LlmApiToken,
 963        app_version: SemanticVersion,
 964    ) -> Result<(Res, Option<EditPredictionUsage>)>
 965    where
 966        Res: DeserializeOwned,
 967    {
 968        let http_client = client.http_client();
 969        let mut token = llm_token.acquire(&client).await?;
 970        let mut did_retry = false;
 971
 972        loop {
 973            let request_builder = http_client::Request::builder().method(Method::POST);
 974
 975            let request = build(
 976                request_builder
 977                    .header("Content-Type", "application/json")
 978                    .header("Authorization", format!("Bearer {}", token))
 979                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 980            )?;
 981
 982            let mut response = http_client.send(request).await?;
 983
 984            if let Some(minimum_required_version) = response
 985                .headers()
 986                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 987                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 988            {
 989                anyhow::ensure!(
 990                    app_version >= minimum_required_version,
 991                    ZedUpdateRequiredError {
 992                        minimum_version: minimum_required_version
 993                    }
 994                );
 995            }
 996
 997            if response.status().is_success() {
 998                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 999
1000                let mut body = Vec::new();
1001                response.body_mut().read_to_end(&mut body).await?;
1002                return Ok((serde_json::from_slice(&body)?, usage));
1003            } else if !did_retry
1004                && response
1005                    .headers()
1006                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1007                    .is_some()
1008            {
1009                did_retry = true;
1010                token = llm_token.refresh(&client).await?;
1011            } else {
1012                let mut body = String::new();
1013                response.body_mut().read_to_string(&mut body).await?;
1014                anyhow::bail!(
1015                    "Request failed with status: {:?}\nBody: {}",
1016                    response.status(),
1017                    body
1018                );
1019            }
1020        }
1021    }
1022
1023    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1024    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1025
1026    // Refresh the related excerpts when the user just beguns editing after
1027    // an idle period, and after they pause editing.
1028    fn refresh_context_if_needed(
1029        &mut self,
1030        project: &Entity<Project>,
1031        buffer: &Entity<language::Buffer>,
1032        cursor_position: language::Anchor,
1033        cx: &mut Context<Self>,
1034    ) {
1035        if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1036            return;
1037        }
1038
1039        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1040            return;
1041        };
1042
1043        let now = Instant::now();
1044        let was_idle = zeta_project
1045            .refresh_context_timestamp
1046            .map_or(true, |timestamp| {
1047                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1048            });
1049        zeta_project.refresh_context_timestamp = Some(now);
1050        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1051            let buffer = buffer.clone();
1052            let project = project.clone();
1053            async move |this, cx| {
1054                if was_idle {
1055                    log::debug!("refetching edit prediction context after idle");
1056                } else {
1057                    cx.background_executor()
1058                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1059                        .await;
1060                    log::debug!("refetching edit prediction context after pause");
1061                }
1062                this.update(cx, |this, cx| {
1063                    this.refresh_context(project, buffer, cursor_position, cx);
1064                })
1065                .ok()
1066            }
1067        }));
1068    }
1069
1070    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1071    // and avoid spawning more than one concurrent task.
1072    fn refresh_context(
1073        &mut self,
1074        project: Entity<Project>,
1075        buffer: Entity<language::Buffer>,
1076        cursor_position: language::Anchor,
1077        cx: &mut Context<Self>,
1078    ) {
1079        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1080            return;
1081        };
1082
1083        let debug_tx = self.debug_tx.clone();
1084
1085        zeta_project
1086            .refresh_context_task
1087            .get_or_insert(cx.spawn(async move |this, cx| {
1088                if let Some(debug_tx) = &debug_tx {
1089                    debug_tx
1090                        .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
1091                            ZetaContextRetrievalDebugInfo {
1092                                project: project.clone(),
1093                                timestamp: Instant::now(),
1094                            },
1095                        ))
1096                        .ok();
1097                }
1098
1099                let related_excerpts = this
1100                    .update(cx, |this, cx| {
1101                        let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1102                            return Task::ready(anyhow::Ok(HashMap::default()));
1103                        };
1104
1105                        let ContextMode::Llm(options) = &this.options().context else {
1106                            return Task::ready(anyhow::Ok(HashMap::default()));
1107                        };
1108
1109                        find_related_excerpts(
1110                            buffer.clone(),
1111                            cursor_position,
1112                            &project,
1113                            zeta_project.events.iter(),
1114                            options,
1115                            debug_tx,
1116                            cx,
1117                        )
1118                    })
1119                    .ok()?
1120                    .await
1121                    .log_err()
1122                    .unwrap_or_default();
1123                this.update(cx, |this, _cx| {
1124                    let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1125                        return;
1126                    };
1127                    zeta_project.context = Some(related_excerpts);
1128                    zeta_project.refresh_context_task.take();
1129                    if let Some(debug_tx) = &this.debug_tx {
1130                        debug_tx
1131                            .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1132                                ZetaContextRetrievalDebugInfo {
1133                                    project,
1134                                    timestamp: Instant::now(),
1135                                },
1136                            ))
1137                            .ok();
1138                    }
1139                })
1140                .ok()
1141            }));
1142    }
1143
1144    fn gather_nearby_diagnostics(
1145        cursor_offset: usize,
1146        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1147        snapshot: &BufferSnapshot,
1148        max_diagnostics_bytes: usize,
1149    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1150        // TODO: Could make this more efficient
1151        let mut diagnostic_groups = Vec::new();
1152        for (language_server_id, diagnostics) in diagnostic_sets {
1153            let mut groups = Vec::new();
1154            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1155            diagnostic_groups.extend(
1156                groups
1157                    .into_iter()
1158                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1159            );
1160        }
1161
1162        // sort by proximity to cursor
1163        diagnostic_groups.sort_by_key(|group| {
1164            let range = &group.entries[group.primary_ix].range;
1165            if range.start >= cursor_offset {
1166                range.start - cursor_offset
1167            } else if cursor_offset >= range.end {
1168                cursor_offset - range.end
1169            } else {
1170                (cursor_offset - range.start).min(range.end - cursor_offset)
1171            }
1172        });
1173
1174        let mut results = Vec::new();
1175        let mut diagnostic_groups_truncated = false;
1176        let mut diagnostics_byte_count = 0;
1177        for group in diagnostic_groups {
1178            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1179            diagnostics_byte_count += raw_value.get().len();
1180            if diagnostics_byte_count > max_diagnostics_bytes {
1181                diagnostic_groups_truncated = true;
1182                break;
1183            }
1184            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1185        }
1186
1187        (results, diagnostic_groups_truncated)
1188    }
1189
1190    // TODO: Dedupe with similar code in request_prediction?
1191    pub fn cloud_request_for_zeta_cli(
1192        &mut self,
1193        project: &Entity<Project>,
1194        buffer: &Entity<Buffer>,
1195        position: language::Anchor,
1196        cx: &mut Context<Self>,
1197    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1198        let project_state = self.projects.get(&project.entity_id());
1199
1200        let index_state = project_state.map(|state| {
1201            state
1202                .syntax_index
1203                .read_with(cx, |index, _cx| index.state().clone())
1204        });
1205        let options = self.options.clone();
1206        let snapshot = buffer.read(cx).snapshot();
1207        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1208            return Task::ready(Err(anyhow!("No file path for excerpt")));
1209        };
1210        let worktree_snapshots = project
1211            .read(cx)
1212            .worktrees(cx)
1213            .map(|worktree| worktree.read(cx).snapshot())
1214            .collect::<Vec<_>>();
1215
1216        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1217            let mut path = f.worktree.read(cx).absolutize(&f.path);
1218            if path.pop() { Some(path) } else { None }
1219        });
1220
1221        cx.background_spawn(async move {
1222            let index_state = if let Some(index_state) = index_state {
1223                Some(index_state.lock_owned().await)
1224            } else {
1225                None
1226            };
1227
1228            let cursor_point = position.to_point(&snapshot);
1229
1230            let debug_info = true;
1231            EditPredictionContext::gather_context(
1232                cursor_point,
1233                &snapshot,
1234                parent_abs_path.as_deref(),
1235                match &options.context {
1236                    ContextMode::Llm(_) => {
1237                        // TODO
1238                        panic!("Llm mode not supported in zeta cli yet");
1239                    }
1240                    ContextMode::Syntax(edit_prediction_context_options) => {
1241                        edit_prediction_context_options
1242                    }
1243                },
1244                index_state.as_deref(),
1245            )
1246            .context("Failed to select excerpt")
1247            .map(|context| {
1248                make_syntax_context_cloud_request(
1249                    excerpt_path.into(),
1250                    context,
1251                    // TODO pass everything
1252                    Vec::new(),
1253                    false,
1254                    Vec::new(),
1255                    false,
1256                    None,
1257                    debug_info,
1258                    &worktree_snapshots,
1259                    index_state.as_deref(),
1260                    Some(options.max_prompt_bytes),
1261                    options.prompt_format,
1262                )
1263            })
1264        })
1265    }
1266
1267    pub fn wait_for_initial_indexing(
1268        &mut self,
1269        project: &Entity<Project>,
1270        cx: &mut App,
1271    ) -> Task<Result<()>> {
1272        let zeta_project = self.get_or_init_zeta_project(project, cx);
1273        zeta_project
1274            .syntax_index
1275            .read(cx)
1276            .wait_for_initial_file_indexing(cx)
1277    }
1278}
1279
1280#[derive(Error, Debug)]
1281#[error(
1282    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1283)]
1284pub struct ZedUpdateRequiredError {
1285    minimum_version: SemanticVersion,
1286}
1287
1288fn make_syntax_context_cloud_request(
1289    excerpt_path: Arc<Path>,
1290    context: EditPredictionContext,
1291    events: Vec<predict_edits_v3::Event>,
1292    can_collect_data: bool,
1293    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1294    diagnostic_groups_truncated: bool,
1295    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1296    debug_info: bool,
1297    worktrees: &Vec<worktree::Snapshot>,
1298    index_state: Option<&SyntaxIndexState>,
1299    prompt_max_bytes: Option<usize>,
1300    prompt_format: PromptFormat,
1301) -> predict_edits_v3::PredictEditsRequest {
1302    let mut signatures = Vec::new();
1303    let mut declaration_to_signature_index = HashMap::default();
1304    let mut referenced_declarations = Vec::new();
1305
1306    for snippet in context.declarations {
1307        let project_entry_id = snippet.declaration.project_entry_id();
1308        let Some(path) = worktrees.iter().find_map(|worktree| {
1309            worktree.entry_for_id(project_entry_id).map(|entry| {
1310                let mut full_path = RelPathBuf::new();
1311                full_path.push(worktree.root_name());
1312                full_path.push(&entry.path);
1313                full_path
1314            })
1315        }) else {
1316            continue;
1317        };
1318
1319        let parent_index = index_state.and_then(|index_state| {
1320            snippet.declaration.parent().and_then(|parent| {
1321                add_signature(
1322                    parent,
1323                    &mut declaration_to_signature_index,
1324                    &mut signatures,
1325                    index_state,
1326                )
1327            })
1328        });
1329
1330        let (text, text_is_truncated) = snippet.declaration.item_text();
1331        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1332            path: path.as_std_path().into(),
1333            text: text.into(),
1334            range: snippet.declaration.item_line_range(),
1335            text_is_truncated,
1336            signature_range: snippet.declaration.signature_range_in_item_text(),
1337            parent_index,
1338            signature_score: snippet.score(DeclarationStyle::Signature),
1339            declaration_score: snippet.score(DeclarationStyle::Declaration),
1340            score_components: snippet.components,
1341        });
1342    }
1343
1344    let excerpt_parent = index_state.and_then(|index_state| {
1345        context
1346            .excerpt
1347            .parent_declarations
1348            .last()
1349            .and_then(|(parent, _)| {
1350                add_signature(
1351                    *parent,
1352                    &mut declaration_to_signature_index,
1353                    &mut signatures,
1354                    index_state,
1355                )
1356            })
1357    });
1358
1359    predict_edits_v3::PredictEditsRequest {
1360        excerpt_path,
1361        excerpt: context.excerpt_text.body,
1362        excerpt_line_range: context.excerpt.line_range,
1363        excerpt_range: context.excerpt.range,
1364        cursor_point: predict_edits_v3::Point {
1365            line: predict_edits_v3::Line(context.cursor_point.row),
1366            column: context.cursor_point.column,
1367        },
1368        referenced_declarations,
1369        included_files: vec![],
1370        signatures,
1371        excerpt_parent,
1372        events,
1373        can_collect_data,
1374        diagnostic_groups,
1375        diagnostic_groups_truncated,
1376        git_info,
1377        debug_info,
1378        prompt_max_bytes,
1379        prompt_format,
1380    }
1381}
1382
1383fn add_signature(
1384    declaration_id: DeclarationId,
1385    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1386    signatures: &mut Vec<Signature>,
1387    index: &SyntaxIndexState,
1388) -> Option<usize> {
1389    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1390        return Some(*signature_index);
1391    }
1392    let Some(parent_declaration) = index.declaration(declaration_id) else {
1393        log::error!("bug: missing parent declaration");
1394        return None;
1395    };
1396    let parent_index = parent_declaration.parent().and_then(|parent| {
1397        add_signature(parent, declaration_to_signature_index, signatures, index)
1398    });
1399    let (text, text_is_truncated) = parent_declaration.signature_text();
1400    let signature_index = signatures.len();
1401    signatures.push(Signature {
1402        text: text.into(),
1403        text_is_truncated,
1404        parent_index,
1405        range: parent_declaration.signature_line_range(),
1406    });
1407    declaration_to_signature_index.insert(declaration_id, signature_index);
1408    Some(signature_index)
1409}
1410
1411#[cfg(test)]
1412mod tests {
1413    use std::{
1414        path::{Path, PathBuf},
1415        sync::Arc,
1416    };
1417
1418    use client::UserStore;
1419    use clock::FakeSystemClock;
1420    use cloud_llm_client::predict_edits_v3::{self, Point};
1421    use edit_prediction_context::Line;
1422    use futures::{
1423        AsyncReadExt, StreamExt,
1424        channel::{mpsc, oneshot},
1425    };
1426    use gpui::{
1427        Entity, TestAppContext,
1428        http_client::{FakeHttpClient, Response},
1429        prelude::*,
1430    };
1431    use indoc::indoc;
1432    use language::{LanguageServerId, OffsetRangeExt as _};
1433    use pretty_assertions::{assert_eq, assert_matches};
1434    use project::{FakeFs, Project};
1435    use serde_json::json;
1436    use settings::SettingsStore;
1437    use util::path;
1438    use uuid::Uuid;
1439
1440    use crate::{BufferEditPrediction, Zeta};
1441
    /// Checks how `current_prediction_for_buffer` classifies a prediction: `Local` for the
    /// buffer it edits, and `Jump` (targeting the edited file) for any other buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // A prediction that edits the current file (1.txt) is reported as `Local`.

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Wait for the outgoing request, then answer it through the test channel.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // A prediction that edits another file (2.txt) is reported as a `Jump` from buffer1.
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // 2.txt has the same number of lines as 1.txt, so snapshot1's row count is reused
        // for the whole-file range.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Once the target file is open, the same prediction is `Local` for its buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1538
    /// Checks the request built for a cursor position (excerpt path and cursor point), and
    /// that a whole-file replacement in the response is reduced to the text inserted at the
    /// cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor right after "How" on the second line.
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        // Respond with a replacement of every line of the file.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // Only the insertion at the cursor remains as the predicted edit.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }
1600
1601    #[gpui::test]
1602    async fn test_request_events(cx: &mut TestAppContext) {
1603        let (zeta, mut req_rx) = init_test(cx);
1604        let fs = FakeFs::new(cx.executor());
1605        fs.insert_tree(
1606            "/root",
1607            json!({
1608                "foo.md": "Hello!\n\nBye"
1609            }),
1610        )
1611        .await;
1612        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1613
1614        let buffer = project
1615            .update(cx, |project, cx| {
1616                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1617                project.open_buffer(path, cx)
1618            })
1619            .await
1620            .unwrap();
1621
1622        zeta.update(cx, |zeta, cx| {
1623            zeta.register_buffer(&buffer, &project, cx);
1624        });
1625
1626        buffer.update(cx, |buffer, cx| {
1627            buffer.edit(vec![(7..7, "How")], None, cx);
1628        });
1629
1630        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1631        let position = snapshot.anchor_before(language::Point::new(1, 3));
1632
1633        let prediction_task = zeta.update(cx, |zeta, cx| {
1634            zeta.request_prediction(&project, &buffer, position, cx)
1635        });
1636
1637        let (request, respond_tx) = req_rx.next().await.unwrap();
1638
1639        assert_eq!(request.events.len(), 1);
1640        assert_eq!(
1641            request.events[0],
1642            predict_edits_v3::Event::BufferChange {
1643                path: Some(PathBuf::from(path!("root/foo.md"))),
1644                old_path: None,
1645                diff: indoc! {"
1646                        @@ -1,3 +1,3 @@
1647                         Hello!
1648                        -
1649                        +How
1650                         Bye
1651                    "}
1652                .to_string(),
1653                predicted: false
1654            }
1655        );
1656
1657        respond_tx
1658            .send(predict_edits_v3::PredictEditsResponse {
1659                request_id: Uuid::new_v4(),
1660                edits: vec![predict_edits_v3::Edit {
1661                    path: Path::new(path!("root/foo.md")).into(),
1662                    range: Line(0)..Line(snapshot.max_point().row + 1),
1663                    content: "Hello!\nHow are you?\nBye".into(),
1664                }],
1665                debug_info: None,
1666            })
1667            .unwrap();
1668
1669        let prediction = prediction_task.await.unwrap().unwrap();
1670
1671        assert_eq!(prediction.edits.len(), 1);
1672        assert_eq!(
1673            prediction.edits[0].0.to_point(&snapshot).start,
1674            language::Point::new(1, 3)
1675        );
1676        assert_eq!(prediction.edits[0].1, " are you?");
1677    }
1678
1679    #[gpui::test]
1680    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1681        let (zeta, mut req_rx) = init_test(cx);
1682        let fs = FakeFs::new(cx.executor());
1683        fs.insert_tree(
1684            "/root",
1685            json!({
1686                "foo.md": "Hello!\nBye"
1687            }),
1688        )
1689        .await;
1690        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1691
1692        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1693        let diagnostic = lsp::Diagnostic {
1694            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1695            severity: Some(lsp::DiagnosticSeverity::ERROR),
1696            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1697            ..Default::default()
1698        };
1699
1700        project.update(cx, |project, cx| {
1701            project.lsp_store().update(cx, |lsp_store, cx| {
1702                // Create some diagnostics
1703                lsp_store
1704                    .update_diagnostics(
1705                        LanguageServerId(0),
1706                        lsp::PublishDiagnosticsParams {
1707                            uri: path_to_buffer_uri.clone(),
1708                            diagnostics: vec![diagnostic],
1709                            version: None,
1710                        },
1711                        None,
1712                        language::DiagnosticSourceKind::Pushed,
1713                        &[],
1714                        cx,
1715                    )
1716                    .unwrap();
1717            });
1718        });
1719
1720        let buffer = project
1721            .update(cx, |project, cx| {
1722                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1723                project.open_buffer(path, cx)
1724            })
1725            .await
1726            .unwrap();
1727
1728        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1729        let position = snapshot.anchor_before(language::Point::new(0, 0));
1730
1731        let _prediction_task = zeta.update(cx, |zeta, cx| {
1732            zeta.request_prediction(&project, &buffer, position, cx)
1733        });
1734
1735        let (request, _respond_tx) = req_rx.next().await.unwrap();
1736
1737        assert_eq!(request.diagnostic_groups.len(), 1);
1738        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1739            .unwrap();
1740        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1741        assert_eq!(
1742            value,
1743            json!({
1744                "entries": [{
1745                    "range": {
1746                        "start": 8,
1747                        "end": 10
1748                    },
1749                    "diagnostic": {
1750                        "source": null,
1751                        "code": null,
1752                        "code_description": null,
1753                        "severity": 1,
1754                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1755                        "markdown": null,
1756                        "group_id": 0,
1757                        "is_primary": true,
1758                        "is_disk_based": false,
1759                        "is_unnecessary": false,
1760                        "source_kind": "Pushed",
1761                        "data": null,
1762                        "underline": true
1763                    }
1764                }],
1765                "primary_ix": 0
1766            })
1767        );
1768    }
1769
1770    fn init_test(
1771        cx: &mut TestAppContext,
1772    ) -> (
1773        Entity<Zeta>,
1774        mpsc::UnboundedReceiver<(
1775            predict_edits_v3::PredictEditsRequest,
1776            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1777        )>,
1778    ) {
1779        cx.update(move |cx| {
1780            let settings_store = SettingsStore::test(cx);
1781            cx.set_global(settings_store);
1782            language::init(cx);
1783            Project::init_settings(cx);
1784
1785            let (req_tx, req_rx) = mpsc::unbounded();
1786
1787            let http_client = FakeHttpClient::create({
1788                move |req| {
1789                    let uri = req.uri().path().to_string();
1790                    let mut body = req.into_body();
1791                    let req_tx = req_tx.clone();
1792                    async move {
1793                        let resp = match uri.as_str() {
1794                            "/client/llm_tokens" => serde_json::to_string(&json!({
1795                                "token": "test"
1796                            }))
1797                            .unwrap(),
1798                            "/predict_edits/v3" => {
1799                                let mut buf = Vec::new();
1800                                body.read_to_end(&mut buf).await.ok();
1801                                let req = serde_json::from_slice(&buf).unwrap();
1802
1803                                let (res_tx, res_rx) = oneshot::channel();
1804                                req_tx.unbounded_send((req, res_tx)).unwrap();
1805                                serde_json::to_string(&res_rx.await?).unwrap()
1806                            }
1807                            _ => {
1808                                panic!("Unexpected path: {}", uri)
1809                            }
1810                        };
1811
1812                        Ok(Response::builder().body(resp.into()).unwrap())
1813                    }
1814                }
1815            });
1816
1817            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1818            client.cloud_client().set_credentials(1, "test".into());
1819
1820            language_model::init(client.clone(), cx);
1821
1822            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1823            let zeta = Zeta::global(&client, &user_store, cx);
1824
1825            (zeta, req_rx)
1826        })
1827    }
1828}