zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
  10use collections::HashMap;
  11use edit_prediction_context::{
  12    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  13    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  14    SyntaxIndex, SyntaxIndexState,
  15};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::AsyncReadExt as _;
  18use futures::channel::{mpsc, oneshot};
  19use gpui::http_client::{AsyncBody, Method};
  20use gpui::{
  21    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  22    http_client, prelude::*,
  23};
  24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language::{BufferSnapshot, OffsetRangeExt};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use serde::de::DeserializeOwned;
  30use std::collections::{VecDeque, hash_map};
  31use std::fmt::Write;
  32use std::ops::Range;
  33use std::path::Path;
  34use std::str::FromStr as _;
  35use std::sync::Arc;
  36use std::time::{Duration, Instant};
  37use thiserror::Error;
  38use util::rel_path::RelPathBuf;
  39use util::{LogErrorFuture, TryFutureExt};
  40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  41
  42pub mod merge_excerpts;
  43mod prediction;
  44mod provider;
  45pub mod related_excerpts;
  46
  47use crate::merge_excerpts::merge_excerpts;
  48use crate::prediction::EditPrediction;
  49use crate::related_excerpts::find_related_excerpts;
  50pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
  51pub use provider::ZetaEditPredictionProvider;
  52
  53/// Maximum number of events to track.
  54const MAX_EVENT_COUNT: usize = 16;
  55
  56pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  57    max_bytes: 512,
  58    min_bytes: 128,
  59    target_before_cursor_over_total_bytes: 0.5,
  60};
  61
  62pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
  63
  64pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
  65    excerpt: DEFAULT_EXCERPT_OPTIONS,
  66};
  67
  68pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
  69    EditPredictionContextOptions {
  70        use_imports: true,
  71        max_retrieved_declarations: 0,
  72        excerpt: DEFAULT_EXCERPT_OPTIONS,
  73        score: EditPredictionScoreOptions {
  74            omit_excerpt_overlaps: true,
  75        },
  76    };
  77
  78pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  79    context: DEFAULT_CONTEXT_OPTIONS,
  80    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  81    max_diagnostic_bytes: 2048,
  82    prompt_format: PromptFormat::DEFAULT,
  83    file_indexing_parallelism: 1,
  84    buffer_change_grouping_interval: Duration::from_secs(1),
  85};
  86
  87pub struct Zeta2FeatureFlag;
  88
  89impl FeatureFlag for Zeta2FeatureFlag {
  90    const NAME: &'static str = "zeta2";
  91
  92    fn enabled_for_staff() -> bool {
  93        false
  94    }
  95}
  96
  97#[derive(Clone)]
  98struct ZetaGlobal(Entity<Zeta>);
  99
 100impl Global for ZetaGlobal {}
 101
 102pub struct Zeta {
 103    client: Arc<Client>,
 104    user_store: Entity<UserStore>,
 105    llm_token: LlmApiToken,
 106    _llm_token_subscription: Subscription,
 107    projects: HashMap<EntityId, ZetaProject>,
 108    options: ZetaOptions,
 109    update_required: bool,
 110    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
 111}
 112
 113#[derive(Debug, Clone, PartialEq)]
 114pub struct ZetaOptions {
 115    pub context: ContextMode,
 116    pub max_prompt_bytes: usize,
 117    pub max_diagnostic_bytes: usize,
 118    pub prompt_format: predict_edits_v3::PromptFormat,
 119    pub file_indexing_parallelism: usize,
 120    pub buffer_change_grouping_interval: Duration,
 121}
 122
 123#[derive(Debug, Clone, PartialEq)]
 124pub enum ContextMode {
 125    Llm(LlmContextOptions),
 126    Syntax(EditPredictionContextOptions),
 127}
 128
 129impl ContextMode {
 130    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 131        match self {
 132            ContextMode::Llm(options) => &options.excerpt,
 133            ContextMode::Syntax(options) => &options.excerpt,
 134        }
 135    }
 136}
 137
 138#[derive(Debug)]
 139pub enum ZetaDebugInfo {
 140    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
 141    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
 142    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
 143    SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
 144    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
 145    EditPredicted(ZetaEditPredictionDebugInfo),
 146}
 147
 148#[derive(Debug)]
 149pub struct ZetaContextRetrievalStartedDebugInfo {
 150    pub project: Entity<Project>,
 151    pub timestamp: Instant,
 152    pub search_prompt: String,
 153}
 154
 155#[derive(Debug)]
 156pub struct ZetaContextRetrievalDebugInfo {
 157    pub project: Entity<Project>,
 158    pub timestamp: Instant,
 159}
 160
 161#[derive(Debug)]
 162pub struct ZetaEditPredictionDebugInfo {
 163    pub request: predict_edits_v3::PredictEditsRequest,
 164    pub retrieval_time: TimeDelta,
 165    pub buffer: WeakEntity<Buffer>,
 166    pub position: language::Anchor,
 167    pub local_prompt: Result<String, String>,
 168    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
 169}
 170
 171#[derive(Debug)]
 172pub struct ZetaSearchQueryDebugInfo {
 173    pub project: Entity<Project>,
 174    pub timestamp: Instant,
 175    pub queries: Vec<SearchToolQuery>,
 176}
 177
 178pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 179
 180struct ZetaProject {
 181    syntax_index: Entity<SyntaxIndex>,
 182    events: VecDeque<Event>,
 183    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 184    current_prediction: Option<CurrentEditPrediction>,
 185    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
 186    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
 187    refresh_context_debounce_task: Option<Task<Option<()>>>,
 188    refresh_context_timestamp: Option<Instant>,
 189}
 190
 191#[derive(Debug, Clone)]
 192struct CurrentEditPrediction {
 193    pub requested_by_buffer_id: EntityId,
 194    pub prediction: EditPrediction,
 195}
 196
 197impl CurrentEditPrediction {
 198    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 199        let Some(new_edits) = self
 200            .prediction
 201            .interpolate(&self.prediction.buffer.read(cx))
 202        else {
 203            return false;
 204        };
 205
 206        if self.prediction.buffer != old_prediction.prediction.buffer {
 207            return true;
 208        }
 209
 210        let Some(old_edits) = old_prediction
 211            .prediction
 212            .interpolate(&old_prediction.prediction.buffer.read(cx))
 213        else {
 214            return true;
 215        };
 216
 217        // This reduces the occurrence of UI thrash from replacing edits
 218        //
 219        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 220        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 221            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 222            && old_edits.len() == 1
 223            && new_edits.len() == 1
 224        {
 225            let (old_range, old_text) = &old_edits[0];
 226            let (new_range, new_text) = &new_edits[0];
 227            new_range == old_range && new_text.starts_with(old_text)
 228        } else {
 229            true
 230        }
 231    }
 232}
 233
 234/// A prediction from the perspective of a buffer.
 235#[derive(Debug)]
 236enum BufferEditPrediction<'a> {
 237    Local { prediction: &'a EditPrediction },
 238    Jump { prediction: &'a EditPrediction },
 239}
 240
 241struct RegisteredBuffer {
 242    snapshot: BufferSnapshot,
 243    _subscriptions: [gpui::Subscription; 2],
 244}
 245
 246#[derive(Clone)]
 247pub enum Event {
 248    BufferChange {
 249        old_snapshot: BufferSnapshot,
 250        new_snapshot: BufferSnapshot,
 251        timestamp: Instant,
 252    },
 253}
 254
 255impl Event {
 256    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 257        match self {
 258            Event::BufferChange {
 259                old_snapshot,
 260                new_snapshot,
 261                ..
 262            } => {
 263                let path = new_snapshot.file().map(|f| f.full_path(cx));
 264
 265                let old_path = old_snapshot.file().and_then(|f| {
 266                    let old_path = f.full_path(cx);
 267                    if Some(&old_path) != path.as_ref() {
 268                        Some(old_path)
 269                    } else {
 270                        None
 271                    }
 272                });
 273
 274                // TODO [zeta2] move to bg?
 275                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 276
 277                if path == old_path && diff.is_empty() {
 278                    None
 279                } else {
 280                    Some(predict_edits_v3::Event::BufferChange {
 281                        old_path,
 282                        path,
 283                        diff,
 284                        //todo: Actually detect if this edit was predicted or not
 285                        predicted: false,
 286                    })
 287                }
 288            }
 289        }
 290    }
 291}
 292
 293impl Zeta {
 294    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 295        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 296    }
 297
 298    pub fn global(
 299        client: &Arc<Client>,
 300        user_store: &Entity<UserStore>,
 301        cx: &mut App,
 302    ) -> Entity<Self> {
 303        cx.try_global::<ZetaGlobal>()
 304            .map(|global| global.0.clone())
 305            .unwrap_or_else(|| {
 306                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 307                cx.set_global(ZetaGlobal(zeta.clone()));
 308                zeta
 309            })
 310    }
 311
 312    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 313        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 314
 315        Self {
 316            projects: HashMap::default(),
 317            client,
 318            user_store,
 319            options: DEFAULT_OPTIONS,
 320            llm_token: LlmApiToken::default(),
 321            _llm_token_subscription: cx.subscribe(
 322                &refresh_llm_token_listener,
 323                |this, _listener, _event, cx| {
 324                    let client = this.client.clone();
 325                    let llm_token = this.llm_token.clone();
 326                    cx.spawn(async move |_this, _cx| {
 327                        llm_token.refresh(&client).await?;
 328                        anyhow::Ok(())
 329                    })
 330                    .detach_and_log_err(cx);
 331                },
 332            ),
 333            update_required: false,
 334            debug_tx: None,
 335        }
 336    }
 337
 338    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 339        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 340        self.debug_tx = Some(debug_watch_tx);
 341        debug_watch_rx
 342    }
 343
 344    pub fn options(&self) -> &ZetaOptions {
 345        &self.options
 346    }
 347
 348    pub fn set_options(&mut self, options: ZetaOptions) {
 349        self.options = options;
 350    }
 351
 352    pub fn clear_history(&mut self) {
 353        for zeta_project in self.projects.values_mut() {
 354            zeta_project.events.clear();
 355        }
 356    }
 357
 358    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 359        self.projects
 360            .get(&project.entity_id())
 361            .map(|project| project.events.iter())
 362            .into_iter()
 363            .flatten()
 364    }
 365
 366    pub fn context_for_project(
 367        &self,
 368        project: &Entity<Project>,
 369    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 370        self.projects
 371            .get(&project.entity_id())
 372            .and_then(|project| {
 373                Some(
 374                    project
 375                        .context
 376                        .as_ref()?
 377                        .iter()
 378                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 379                )
 380            })
 381            .into_iter()
 382            .flatten()
 383    }
 384
 385    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 386        self.user_store.read(cx).edit_prediction_usage()
 387    }
 388
 389    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 390        self.get_or_init_zeta_project(project, cx);
 391    }
 392
 393    pub fn register_buffer(
 394        &mut self,
 395        buffer: &Entity<Buffer>,
 396        project: &Entity<Project>,
 397        cx: &mut Context<Self>,
 398    ) {
 399        let zeta_project = self.get_or_init_zeta_project(project, cx);
 400        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 401    }
 402
 403    fn get_or_init_zeta_project(
 404        &mut self,
 405        project: &Entity<Project>,
 406        cx: &mut App,
 407    ) -> &mut ZetaProject {
 408        self.projects
 409            .entry(project.entity_id())
 410            .or_insert_with(|| ZetaProject {
 411                syntax_index: cx.new(|cx| {
 412                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 413                }),
 414                events: VecDeque::new(),
 415                registered_buffers: HashMap::default(),
 416                current_prediction: None,
 417                context: None,
 418                refresh_context_task: None,
 419                refresh_context_debounce_task: None,
 420                refresh_context_timestamp: None,
 421            })
 422    }
 423
 424    fn register_buffer_impl<'a>(
 425        zeta_project: &'a mut ZetaProject,
 426        buffer: &Entity<Buffer>,
 427        project: &Entity<Project>,
 428        cx: &mut Context<Self>,
 429    ) -> &'a mut RegisteredBuffer {
 430        let buffer_id = buffer.entity_id();
 431        match zeta_project.registered_buffers.entry(buffer_id) {
 432            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 433            hash_map::Entry::Vacant(entry) => {
 434                let snapshot = buffer.read(cx).snapshot();
 435                let project_entity_id = project.entity_id();
 436                entry.insert(RegisteredBuffer {
 437                    snapshot,
 438                    _subscriptions: [
 439                        cx.subscribe(buffer, {
 440                            let project = project.downgrade();
 441                            move |this, buffer, event, cx| {
 442                                if let language::BufferEvent::Edited = event
 443                                    && let Some(project) = project.upgrade()
 444                                {
 445                                    this.report_changes_for_buffer(&buffer, &project, cx);
 446                                }
 447                            }
 448                        }),
 449                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 450                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 451                            else {
 452                                return;
 453                            };
 454                            zeta_project.registered_buffers.remove(&buffer_id);
 455                        }),
 456                    ],
 457                })
 458            }
 459        }
 460    }
 461
 462    fn report_changes_for_buffer(
 463        &mut self,
 464        buffer: &Entity<Buffer>,
 465        project: &Entity<Project>,
 466        cx: &mut Context<Self>,
 467    ) -> BufferSnapshot {
 468        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
 469        let zeta_project = self.get_or_init_zeta_project(project, cx);
 470        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 471
 472        let new_snapshot = buffer.read(cx).snapshot();
 473        if new_snapshot.version != registered_buffer.snapshot.version {
 474            let old_snapshot =
 475                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 476            Self::push_event(
 477                zeta_project,
 478                buffer_change_grouping_interval,
 479                Event::BufferChange {
 480                    old_snapshot,
 481                    new_snapshot: new_snapshot.clone(),
 482                    timestamp: Instant::now(),
 483                },
 484            );
 485        }
 486
 487        new_snapshot
 488    }
 489
 490    fn push_event(
 491        zeta_project: &mut ZetaProject,
 492        buffer_change_grouping_interval: Duration,
 493        event: Event,
 494    ) {
 495        let events = &mut zeta_project.events;
 496
 497        if buffer_change_grouping_interval > Duration::ZERO
 498            && let Some(Event::BufferChange {
 499                new_snapshot: last_new_snapshot,
 500                timestamp: last_timestamp,
 501                ..
 502            }) = events.back_mut()
 503        {
 504            // Coalesce edits for the same buffer when they happen one after the other.
 505            let Event::BufferChange {
 506                old_snapshot,
 507                new_snapshot,
 508                timestamp,
 509            } = &event;
 510
 511            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
 512                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 513                && old_snapshot.version == last_new_snapshot.version
 514            {
 515                *last_new_snapshot = new_snapshot.clone();
 516                *last_timestamp = *timestamp;
 517                return;
 518            }
 519        }
 520
 521        if events.len() >= MAX_EVENT_COUNT {
 522            // These are halved instead of popping to improve prompt caching.
 523            events.drain(..MAX_EVENT_COUNT / 2);
 524        }
 525
 526        events.push_back(event);
 527    }
 528
 529    fn current_prediction_for_buffer(
 530        &self,
 531        buffer: &Entity<Buffer>,
 532        project: &Entity<Project>,
 533        cx: &App,
 534    ) -> Option<BufferEditPrediction<'_>> {
 535        let project_state = self.projects.get(&project.entity_id())?;
 536
 537        let CurrentEditPrediction {
 538            requested_by_buffer_id,
 539            prediction,
 540        } = project_state.current_prediction.as_ref()?;
 541
 542        if prediction.targets_buffer(buffer.read(cx), cx) {
 543            Some(BufferEditPrediction::Local { prediction })
 544        } else if *requested_by_buffer_id == buffer.entity_id() {
 545            Some(BufferEditPrediction::Jump { prediction })
 546        } else {
 547            None
 548        }
 549    }
 550
 551    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 552        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 553            return;
 554        };
 555
 556        let Some(prediction) = project_state.current_prediction.take() else {
 557            return;
 558        };
 559        let request_id = prediction.prediction.id.into();
 560
 561        let client = self.client.clone();
 562        let llm_token = self.llm_token.clone();
 563        let app_version = AppVersion::global(cx);
 564        cx.spawn(async move |this, cx| {
 565            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 566                http_client::Url::parse(&predict_edits_url)?
 567            } else {
 568                client
 569                    .http_client()
 570                    .build_zed_llm_url("/predict_edits/accept", &[])?
 571            };
 572
 573            let response = cx
 574                .background_spawn(Self::send_api_request::<()>(
 575                    move |builder| {
 576                        let req = builder.uri(url.as_ref()).body(
 577                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 578                        );
 579                        Ok(req?)
 580                    },
 581                    client,
 582                    llm_token,
 583                    app_version,
 584                ))
 585                .await;
 586
 587            Self::handle_api_response(&this, response, cx)?;
 588            anyhow::Ok(())
 589        })
 590        .detach_and_log_err(cx);
 591    }
 592
 593    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 594        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 595            project_state.current_prediction.take();
 596        };
 597    }
 598
 599    pub fn refresh_prediction(
 600        &mut self,
 601        project: &Entity<Project>,
 602        buffer: &Entity<Buffer>,
 603        position: language::Anchor,
 604        cx: &mut Context<Self>,
 605    ) -> Task<Result<()>> {
 606        let request_task = self.request_prediction(project, buffer, position, cx);
 607        let buffer = buffer.clone();
 608        let project = project.clone();
 609
 610        cx.spawn(async move |this, cx| {
 611            if let Some(prediction) = request_task.await? {
 612                this.update(cx, |this, cx| {
 613                    let project_state = this
 614                        .projects
 615                        .get_mut(&project.entity_id())
 616                        .context("Project not found")?;
 617
 618                    let new_prediction = CurrentEditPrediction {
 619                        requested_by_buffer_id: buffer.entity_id(),
 620                        prediction: prediction,
 621                    };
 622
 623                    if project_state
 624                        .current_prediction
 625                        .as_ref()
 626                        .is_none_or(|old_prediction| {
 627                            new_prediction.should_replace_prediction(&old_prediction, cx)
 628                        })
 629                    {
 630                        project_state.current_prediction = Some(new_prediction);
 631                    }
 632                    anyhow::Ok(())
 633                })??;
 634            }
 635            Ok(())
 636        })
 637    }
 638
 639    pub fn request_prediction(
 640        &mut self,
 641        project: &Entity<Project>,
 642        buffer: &Entity<Buffer>,
 643        position: language::Anchor,
 644        cx: &mut Context<Self>,
 645    ) -> Task<Result<Option<EditPrediction>>> {
 646        let project_state = self.projects.get(&project.entity_id());
 647
 648        let index_state = project_state.map(|state| {
 649            state
 650                .syntax_index
 651                .read_with(cx, |index, _cx| index.state().clone())
 652        });
 653        let options = self.options.clone();
 654        let snapshot = buffer.read(cx).snapshot();
 655        let Some(excerpt_path) = snapshot
 656            .file()
 657            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 658        else {
 659            return Task::ready(Err(anyhow!("No file path for excerpt")));
 660        };
 661        let client = self.client.clone();
 662        let llm_token = self.llm_token.clone();
 663        let app_version = AppVersion::global(cx);
 664        let worktree_snapshots = project
 665            .read(cx)
 666            .worktrees(cx)
 667            .map(|worktree| worktree.read(cx).snapshot())
 668            .collect::<Vec<_>>();
 669        let debug_tx = self.debug_tx.clone();
 670
 671        let events = project_state
 672            .map(|state| {
 673                state
 674                    .events
 675                    .iter()
 676                    .filter_map(|event| event.to_request_event(cx))
 677                    .collect::<Vec<_>>()
 678            })
 679            .unwrap_or_default();
 680
 681        let diagnostics = snapshot.diagnostic_sets().clone();
 682
 683        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 684            let mut path = f.worktree.read(cx).absolutize(&f.path);
 685            if path.pop() { Some(path) } else { None }
 686        });
 687
 688        // TODO data collection
 689        let can_collect_data = cx.is_staff();
 690
 691        let mut included_files = project_state
 692            .and_then(|project_state| project_state.context.as_ref())
 693            .unwrap_or(&HashMap::default())
 694            .iter()
 695            .filter_map(|(buffer, ranges)| {
 696                let buffer = buffer.read(cx);
 697                Some((
 698                    buffer.snapshot(),
 699                    buffer.file()?.full_path(cx).into(),
 700                    ranges.clone(),
 701                ))
 702            })
 703            .collect::<Vec<_>>();
 704
 705        let request_task = cx.background_spawn({
 706            let snapshot = snapshot.clone();
 707            let buffer = buffer.clone();
 708            async move {
 709                let index_state = if let Some(index_state) = index_state {
 710                    Some(index_state.lock_owned().await)
 711                } else {
 712                    None
 713                };
 714
 715                let cursor_offset = position.to_offset(&snapshot);
 716                let cursor_point = cursor_offset.to_point(&snapshot);
 717
 718                let before_retrieval = chrono::Utc::now();
 719
 720                let (diagnostic_groups, diagnostic_groups_truncated) =
 721                    Self::gather_nearby_diagnostics(
 722                        cursor_offset,
 723                        &diagnostics,
 724                        &snapshot,
 725                        options.max_diagnostic_bytes,
 726                    );
 727
 728                let request = match options.context {
 729                    ContextMode::Llm(context_options) => {
 730                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 731                            cursor_point,
 732                            &snapshot,
 733                            &context_options.excerpt,
 734                            index_state.as_deref(),
 735                        ) else {
 736                            return Ok((None, None));
 737                        };
 738
 739                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
 740                            ..snapshot.anchor_before(excerpt.range.end);
 741
 742                        if let Some(buffer_ix) = included_files
 743                            .iter()
 744                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
 745                        {
 746                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
 747                            let range_ix = ranges
 748                                .binary_search_by(|probe| {
 749                                    probe
 750                                        .start
 751                                        .cmp(&excerpt_anchor_range.start, buffer)
 752                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 753                                })
 754                                .unwrap_or_else(|ix| ix);
 755
 756                            ranges.insert(range_ix, excerpt_anchor_range);
 757                            let last_ix = included_files.len() - 1;
 758                            included_files.swap(buffer_ix, last_ix);
 759                        } else {
 760                            included_files.push((
 761                                snapshot,
 762                                excerpt_path.clone(),
 763                                vec![excerpt_anchor_range],
 764                            ));
 765                        }
 766
 767                        let included_files = included_files
 768                            .into_iter()
 769                            .map(|(buffer, path, ranges)| {
 770                                let excerpts = merge_excerpts(
 771                                    &buffer,
 772                                    ranges.iter().map(|range| {
 773                                        let point_range = range.to_point(&buffer);
 774                                        Line(point_range.start.row)..Line(point_range.end.row)
 775                                    }),
 776                                );
 777                                predict_edits_v3::IncludedFile {
 778                                    path,
 779                                    max_row: Line(buffer.max_point().row),
 780                                    excerpts,
 781                                }
 782                            })
 783                            .collect::<Vec<_>>();
 784
 785                        predict_edits_v3::PredictEditsRequest {
 786                            excerpt_path,
 787                            excerpt: String::new(),
 788                            excerpt_line_range: Line(0)..Line(0),
 789                            excerpt_range: 0..0,
 790                            cursor_point: predict_edits_v3::Point {
 791                                line: predict_edits_v3::Line(cursor_point.row),
 792                                column: cursor_point.column,
 793                            },
 794                            included_files,
 795                            referenced_declarations: vec![],
 796                            events,
 797                            can_collect_data,
 798                            diagnostic_groups,
 799                            diagnostic_groups_truncated,
 800                            debug_info: debug_tx.is_some(),
 801                            prompt_max_bytes: Some(options.max_prompt_bytes),
 802                            prompt_format: options.prompt_format,
 803                            // TODO [zeta2]
 804                            signatures: vec![],
 805                            excerpt_parent: None,
 806                            git_info: None,
 807                        }
 808                    }
 809                    ContextMode::Syntax(context_options) => {
 810                        let Some(context) = EditPredictionContext::gather_context(
 811                            cursor_point,
 812                            &snapshot,
 813                            parent_abs_path.as_deref(),
 814                            &context_options,
 815                            index_state.as_deref(),
 816                        ) else {
 817                            return Ok((None, None));
 818                        };
 819
 820                        make_syntax_context_cloud_request(
 821                            excerpt_path,
 822                            context,
 823                            events,
 824                            can_collect_data,
 825                            diagnostic_groups,
 826                            diagnostic_groups_truncated,
 827                            None,
 828                            debug_tx.is_some(),
 829                            &worktree_snapshots,
 830                            index_state.as_deref(),
 831                            Some(options.max_prompt_bytes),
 832                            options.prompt_format,
 833                        )
 834                    }
 835                };
 836
 837                let retrieval_time = chrono::Utc::now() - before_retrieval;
 838
 839                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 840                    let (response_tx, response_rx) = oneshot::channel();
 841
 842                    let local_prompt = build_prompt(&request)
 843                        .map(|(prompt, _)| prompt)
 844                        .map_err(|err| err.to_string());
 845
 846                    debug_tx
 847                        .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
 848                            request: request.clone(),
 849                            retrieval_time,
 850                            buffer: buffer.downgrade(),
 851                            local_prompt,
 852                            position,
 853                            response_rx,
 854                        }))
 855                        .ok();
 856                    Some(response_tx)
 857                } else {
 858                    None
 859                };
 860
 861                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 862                    if let Some(debug_response_tx) = debug_response_tx {
 863                        debug_response_tx
 864                            .send(Err("Request skipped".to_string()))
 865                            .ok();
 866                    }
 867                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 868                }
 869
 870                let response =
 871                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 872
 873                if let Some(debug_response_tx) = debug_response_tx {
 874                    debug_response_tx
 875                        .send(
 876                            response
 877                                .as_ref()
 878                                .map_err(|err| err.to_string())
 879                                .map(|response| response.0.clone()),
 880                        )
 881                        .ok();
 882                }
 883
 884                response.map(|(res, usage)| (Some(res), usage))
 885            }
 886        });
 887
 888        let buffer = buffer.clone();
 889
 890        cx.spawn({
 891            let project = project.clone();
 892            async move |this, cx| {
 893                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 894                else {
 895                    return Ok(None);
 896                };
 897
 898                // TODO telemetry: duration, etc
 899                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 900            }
 901        })
 902    }
 903
 904    async fn send_prediction_request(
 905        client: Arc<Client>,
 906        llm_token: LlmApiToken,
 907        app_version: SemanticVersion,
 908        request: predict_edits_v3::PredictEditsRequest,
 909    ) -> Result<(
 910        predict_edits_v3::PredictEditsResponse,
 911        Option<EditPredictionUsage>,
 912    )> {
 913        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 914            http_client::Url::parse(&predict_edits_url)?
 915        } else {
 916            client
 917                .http_client()
 918                .build_zed_llm_url("/predict_edits/v3", &[])?
 919        };
 920
 921        Self::send_api_request(
 922            |builder| {
 923                let req = builder
 924                    .uri(url.as_ref())
 925                    .body(serde_json::to_string(&request)?.into());
 926                Ok(req?)
 927            },
 928            client,
 929            llm_token,
 930            app_version,
 931        )
 932        .await
 933    }
 934
 935    fn handle_api_response<T>(
 936        this: &WeakEntity<Self>,
 937        response: Result<(T, Option<EditPredictionUsage>)>,
 938        cx: &mut gpui::AsyncApp,
 939    ) -> Result<T> {
 940        match response {
 941            Ok((data, usage)) => {
 942                if let Some(usage) = usage {
 943                    this.update(cx, |this, cx| {
 944                        this.user_store.update(cx, |user_store, cx| {
 945                            user_store.update_edit_prediction_usage(usage, cx);
 946                        });
 947                    })
 948                    .ok();
 949                }
 950                Ok(data)
 951            }
 952            Err(err) => {
 953                if err.is::<ZedUpdateRequiredError>() {
 954                    cx.update(|cx| {
 955                        this.update(cx, |this, _cx| {
 956                            this.update_required = true;
 957                        })
 958                        .ok();
 959
 960                        let error_message: SharedString = err.to_string().into();
 961                        show_app_notification(
 962                            NotificationId::unique::<ZedUpdateRequiredError>(),
 963                            cx,
 964                            move |cx| {
 965                                cx.new(|cx| {
 966                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 967                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 968                                })
 969                            },
 970                        );
 971                    })
 972                    .ok();
 973                }
 974                Err(err)
 975            }
 976        }
 977    }
 978
 979    async fn send_api_request<Res>(
 980        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 981        client: Arc<Client>,
 982        llm_token: LlmApiToken,
 983        app_version: SemanticVersion,
 984    ) -> Result<(Res, Option<EditPredictionUsage>)>
 985    where
 986        Res: DeserializeOwned,
 987    {
 988        let http_client = client.http_client();
 989        let mut token = llm_token.acquire(&client).await?;
 990        let mut did_retry = false;
 991
 992        loop {
 993            let request_builder = http_client::Request::builder().method(Method::POST);
 994
 995            let request = build(
 996                request_builder
 997                    .header("Content-Type", "application/json")
 998                    .header("Authorization", format!("Bearer {}", token))
 999                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1000            )?;
1001
1002            let mut response = http_client.send(request).await?;
1003
1004            if let Some(minimum_required_version) = response
1005                .headers()
1006                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1007                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1008            {
1009                anyhow::ensure!(
1010                    app_version >= minimum_required_version,
1011                    ZedUpdateRequiredError {
1012                        minimum_version: minimum_required_version
1013                    }
1014                );
1015            }
1016
1017            if response.status().is_success() {
1018                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1019
1020                let mut body = Vec::new();
1021                response.body_mut().read_to_end(&mut body).await?;
1022                return Ok((serde_json::from_slice(&body)?, usage));
1023            } else if !did_retry
1024                && response
1025                    .headers()
1026                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1027                    .is_some()
1028            {
1029                did_retry = true;
1030                token = llm_token.refresh(&client).await?;
1031            } else {
1032                let mut body = String::new();
1033                response.body_mut().read_to_string(&mut body).await?;
1034                anyhow::bail!(
1035                    "Request failed with status: {:?}\nBody: {}",
1036                    response.status(),
1037                    body
1038                );
1039            }
1040        }
1041    }
1042
1043    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1044    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1045
1046    // Refresh the related excerpts when the user just beguns editing after
1047    // an idle period, and after they pause editing.
1048    fn refresh_context_if_needed(
1049        &mut self,
1050        project: &Entity<Project>,
1051        buffer: &Entity<language::Buffer>,
1052        cursor_position: language::Anchor,
1053        cx: &mut Context<Self>,
1054    ) {
1055        if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1056            return;
1057        }
1058
1059        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1060            return;
1061        };
1062
1063        let now = Instant::now();
1064        let was_idle = zeta_project
1065            .refresh_context_timestamp
1066            .map_or(true, |timestamp| {
1067                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1068            });
1069        zeta_project.refresh_context_timestamp = Some(now);
1070        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1071            let buffer = buffer.clone();
1072            let project = project.clone();
1073            async move |this, cx| {
1074                if was_idle {
1075                    log::debug!("refetching edit prediction context after idle");
1076                } else {
1077                    cx.background_executor()
1078                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1079                        .await;
1080                    log::debug!("refetching edit prediction context after pause");
1081                }
1082                this.update(cx, |this, cx| {
1083                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1084
1085                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1086                        zeta_project.refresh_context_task = Some(task.log_err());
1087                    };
1088                })
1089                .ok()
1090            }
1091        }));
1092    }
1093
1094    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1095    // and avoid spawning more than one concurrent task.
1096    pub fn refresh_context(
1097        &mut self,
1098        project: Entity<Project>,
1099        buffer: Entity<language::Buffer>,
1100        cursor_position: language::Anchor,
1101        cx: &mut Context<Self>,
1102    ) -> Task<Result<()>> {
1103        cx.spawn(async move |this, cx| {
1104            let related_excerpts_result = this
1105                .update(cx, |this, cx| {
1106                    let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1107                        return Task::ready(anyhow::Ok(HashMap::default()));
1108                    };
1109
1110                    let ContextMode::Llm(options) = &this.options().context else {
1111                        return Task::ready(anyhow::Ok(HashMap::default()));
1112                    };
1113
1114                    let mut edit_history_unified_diff = String::new();
1115
1116                    for event in zeta_project.events.iter() {
1117                        if let Some(event) = event.to_request_event(cx) {
1118                            writeln!(&mut edit_history_unified_diff, "{event}").ok();
1119                        }
1120                    }
1121
1122                    find_related_excerpts(
1123                        buffer.clone(),
1124                        cursor_position,
1125                        &project,
1126                        edit_history_unified_diff,
1127                        options,
1128                        this.debug_tx.clone(),
1129                        cx,
1130                    )
1131                })?
1132                .await;
1133
1134            this.update(cx, |this, _cx| {
1135                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1136                    return Ok(());
1137                };
1138                zeta_project.refresh_context_task.take();
1139                if let Some(debug_tx) = &this.debug_tx {
1140                    debug_tx
1141                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1142                            ZetaContextRetrievalDebugInfo {
1143                                project,
1144                                timestamp: Instant::now(),
1145                            },
1146                        ))
1147                        .ok();
1148                }
1149                match related_excerpts_result {
1150                    Ok(excerpts) => {
1151                        zeta_project.context = Some(excerpts);
1152                        Ok(())
1153                    }
1154                    Err(error) => Err(error),
1155                }
1156            })?
1157        })
1158    }
1159
1160    fn gather_nearby_diagnostics(
1161        cursor_offset: usize,
1162        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1163        snapshot: &BufferSnapshot,
1164        max_diagnostics_bytes: usize,
1165    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1166        // TODO: Could make this more efficient
1167        let mut diagnostic_groups = Vec::new();
1168        for (language_server_id, diagnostics) in diagnostic_sets {
1169            let mut groups = Vec::new();
1170            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1171            diagnostic_groups.extend(
1172                groups
1173                    .into_iter()
1174                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1175            );
1176        }
1177
1178        // sort by proximity to cursor
1179        diagnostic_groups.sort_by_key(|group| {
1180            let range = &group.entries[group.primary_ix].range;
1181            if range.start >= cursor_offset {
1182                range.start - cursor_offset
1183            } else if cursor_offset >= range.end {
1184                cursor_offset - range.end
1185            } else {
1186                (cursor_offset - range.start).min(range.end - cursor_offset)
1187            }
1188        });
1189
1190        let mut results = Vec::new();
1191        let mut diagnostic_groups_truncated = false;
1192        let mut diagnostics_byte_count = 0;
1193        for group in diagnostic_groups {
1194            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1195            diagnostics_byte_count += raw_value.get().len();
1196            if diagnostics_byte_count > max_diagnostics_bytes {
1197                diagnostic_groups_truncated = true;
1198                break;
1199            }
1200            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1201        }
1202
1203        (results, diagnostic_groups_truncated)
1204    }
1205
1206    // TODO: Dedupe with similar code in request_prediction?
1207    pub fn cloud_request_for_zeta_cli(
1208        &mut self,
1209        project: &Entity<Project>,
1210        buffer: &Entity<Buffer>,
1211        position: language::Anchor,
1212        cx: &mut Context<Self>,
1213    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1214        let project_state = self.projects.get(&project.entity_id());
1215
1216        let index_state = project_state.map(|state| {
1217            state
1218                .syntax_index
1219                .read_with(cx, |index, _cx| index.state().clone())
1220        });
1221        let options = self.options.clone();
1222        let snapshot = buffer.read(cx).snapshot();
1223        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1224            return Task::ready(Err(anyhow!("No file path for excerpt")));
1225        };
1226        let worktree_snapshots = project
1227            .read(cx)
1228            .worktrees(cx)
1229            .map(|worktree| worktree.read(cx).snapshot())
1230            .collect::<Vec<_>>();
1231
1232        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1233            let mut path = f.worktree.read(cx).absolutize(&f.path);
1234            if path.pop() { Some(path) } else { None }
1235        });
1236
1237        cx.background_spawn(async move {
1238            let index_state = if let Some(index_state) = index_state {
1239                Some(index_state.lock_owned().await)
1240            } else {
1241                None
1242            };
1243
1244            let cursor_point = position.to_point(&snapshot);
1245
1246            let debug_info = true;
1247            EditPredictionContext::gather_context(
1248                cursor_point,
1249                &snapshot,
1250                parent_abs_path.as_deref(),
1251                match &options.context {
1252                    ContextMode::Llm(_) => {
1253                        // TODO
1254                        panic!("Llm mode not supported in zeta cli yet");
1255                    }
1256                    ContextMode::Syntax(edit_prediction_context_options) => {
1257                        edit_prediction_context_options
1258                    }
1259                },
1260                index_state.as_deref(),
1261            )
1262            .context("Failed to select excerpt")
1263            .map(|context| {
1264                make_syntax_context_cloud_request(
1265                    excerpt_path.into(),
1266                    context,
1267                    // TODO pass everything
1268                    Vec::new(),
1269                    false,
1270                    Vec::new(),
1271                    false,
1272                    None,
1273                    debug_info,
1274                    &worktree_snapshots,
1275                    index_state.as_deref(),
1276                    Some(options.max_prompt_bytes),
1277                    options.prompt_format,
1278                )
1279            })
1280        })
1281    }
1282
1283    pub fn wait_for_initial_indexing(
1284        &mut self,
1285        project: &Entity<Project>,
1286        cx: &mut App,
1287    ) -> Task<Result<()>> {
1288        let zeta_project = self.get_or_init_zeta_project(project, cx);
1289        zeta_project
1290            .syntax_index
1291            .read(cx)
1292            .wait_for_initial_file_indexing(cx)
1293    }
1294}
1295
1296#[derive(Error, Debug)]
1297#[error(
1298    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1299)]
1300pub struct ZedUpdateRequiredError {
1301    minimum_version: SemanticVersion,
1302}
1303
1304fn make_syntax_context_cloud_request(
1305    excerpt_path: Arc<Path>,
1306    context: EditPredictionContext,
1307    events: Vec<predict_edits_v3::Event>,
1308    can_collect_data: bool,
1309    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1310    diagnostic_groups_truncated: bool,
1311    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1312    debug_info: bool,
1313    worktrees: &Vec<worktree::Snapshot>,
1314    index_state: Option<&SyntaxIndexState>,
1315    prompt_max_bytes: Option<usize>,
1316    prompt_format: PromptFormat,
1317) -> predict_edits_v3::PredictEditsRequest {
1318    let mut signatures = Vec::new();
1319    let mut declaration_to_signature_index = HashMap::default();
1320    let mut referenced_declarations = Vec::new();
1321
1322    for snippet in context.declarations {
1323        let project_entry_id = snippet.declaration.project_entry_id();
1324        let Some(path) = worktrees.iter().find_map(|worktree| {
1325            worktree.entry_for_id(project_entry_id).map(|entry| {
1326                let mut full_path = RelPathBuf::new();
1327                full_path.push(worktree.root_name());
1328                full_path.push(&entry.path);
1329                full_path
1330            })
1331        }) else {
1332            continue;
1333        };
1334
1335        let parent_index = index_state.and_then(|index_state| {
1336            snippet.declaration.parent().and_then(|parent| {
1337                add_signature(
1338                    parent,
1339                    &mut declaration_to_signature_index,
1340                    &mut signatures,
1341                    index_state,
1342                )
1343            })
1344        });
1345
1346        let (text, text_is_truncated) = snippet.declaration.item_text();
1347        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1348            path: path.as_std_path().into(),
1349            text: text.into(),
1350            range: snippet.declaration.item_line_range(),
1351            text_is_truncated,
1352            signature_range: snippet.declaration.signature_range_in_item_text(),
1353            parent_index,
1354            signature_score: snippet.score(DeclarationStyle::Signature),
1355            declaration_score: snippet.score(DeclarationStyle::Declaration),
1356            score_components: snippet.components,
1357        });
1358    }
1359
1360    let excerpt_parent = index_state.and_then(|index_state| {
1361        context
1362            .excerpt
1363            .parent_declarations
1364            .last()
1365            .and_then(|(parent, _)| {
1366                add_signature(
1367                    *parent,
1368                    &mut declaration_to_signature_index,
1369                    &mut signatures,
1370                    index_state,
1371                )
1372            })
1373    });
1374
1375    predict_edits_v3::PredictEditsRequest {
1376        excerpt_path,
1377        excerpt: context.excerpt_text.body,
1378        excerpt_line_range: context.excerpt.line_range,
1379        excerpt_range: context.excerpt.range,
1380        cursor_point: predict_edits_v3::Point {
1381            line: predict_edits_v3::Line(context.cursor_point.row),
1382            column: context.cursor_point.column,
1383        },
1384        referenced_declarations,
1385        included_files: vec![],
1386        signatures,
1387        excerpt_parent,
1388        events,
1389        can_collect_data,
1390        diagnostic_groups,
1391        diagnostic_groups_truncated,
1392        git_info,
1393        debug_info,
1394        prompt_max_bytes,
1395        prompt_format,
1396    }
1397}
1398
1399fn add_signature(
1400    declaration_id: DeclarationId,
1401    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1402    signatures: &mut Vec<Signature>,
1403    index: &SyntaxIndexState,
1404) -> Option<usize> {
1405    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1406        return Some(*signature_index);
1407    }
1408    let Some(parent_declaration) = index.declaration(declaration_id) else {
1409        log::error!("bug: missing parent declaration");
1410        return None;
1411    };
1412    let parent_index = parent_declaration.parent().and_then(|parent| {
1413        add_signature(parent, declaration_to_signature_index, signatures, index)
1414    });
1415    let (text, text_is_truncated) = parent_declaration.signature_text();
1416    let signature_index = signatures.len();
1417    signatures.push(Signature {
1418        text: text.into(),
1419        text_is_truncated,
1420        parent_index,
1421        range: parent_declaration.signature_line_range(),
1422    });
1423    declaration_to_signature_index.insert(declaration_id, signature_index);
1424    Some(signature_index)
1425}
1426
1427#[cfg(test)]
1428mod tests {
1429    use std::{
1430        path::{Path, PathBuf},
1431        sync::Arc,
1432    };
1433
1434    use client::UserStore;
1435    use clock::FakeSystemClock;
1436    use cloud_llm_client::predict_edits_v3::{self, Point};
1437    use edit_prediction_context::Line;
1438    use futures::{
1439        AsyncReadExt, StreamExt,
1440        channel::{mpsc, oneshot},
1441    };
1442    use gpui::{
1443        Entity, TestAppContext,
1444        http_client::{FakeHttpClient, Response},
1445        prelude::*,
1446    };
1447    use indoc::indoc;
1448    use language::{LanguageServerId, OffsetRangeExt as _};
1449    use pretty_assertions::{assert_eq, assert_matches};
1450    use project::{FakeFs, Project};
1451    use serde_json::json;
1452    use settings::SettingsStore;
1453    use util::path;
1454    use uuid::Uuid;
1455
1456    use crate::{BufferEditPrediction, Zeta};
1457
    // Checks how `current_prediction_for_buffer` classifies a prediction
    // depending on which file it edits.
    // - A prediction that edits `1.txt` is `Local` for `buffer1`.
    // - A later prediction that edits `2.txt` is reported to `buffer1` as a
    //   `Jump` to that path.
    // - The same `2.txt` prediction becomes `Local` once `2.txt` is opened as
    //   `buffer2`.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        // `req_rx` yields each outgoing prediction request paired with a sender
        // for its response (as wired up by `init_test`).
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor on the second line ("How"), after its third character.
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Respond with an edit whose line range spans all of `1.txt`.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // This time the edit targets `2.txt`, which has no open buffer yet.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Opening the target file should turn the same prediction into a local one
        // for that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1554
    // Checks a single request/response round trip.
    // - The outgoing request carries the buffer's path and cursor point.
    // - A whole-file replacement in the response is reduced to a single
    //   insertion of " are you?" at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor on the second line ("How"), after its third character.
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        // Inspect the request that was sent before answering it.
        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        // Respond with an edit whose line range spans the whole file.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // Only the inserted text survives as an edit, anchored at the cursor.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }
1616
1617    #[gpui::test]
1618    async fn test_request_events(cx: &mut TestAppContext) {
1619        let (zeta, mut req_rx) = init_test(cx);
1620        let fs = FakeFs::new(cx.executor());
1621        fs.insert_tree(
1622            "/root",
1623            json!({
1624                "foo.md": "Hello!\n\nBye"
1625            }),
1626        )
1627        .await;
1628        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1629
1630        let buffer = project
1631            .update(cx, |project, cx| {
1632                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1633                project.open_buffer(path, cx)
1634            })
1635            .await
1636            .unwrap();
1637
1638        zeta.update(cx, |zeta, cx| {
1639            zeta.register_buffer(&buffer, &project, cx);
1640        });
1641
1642        buffer.update(cx, |buffer, cx| {
1643            buffer.edit(vec![(7..7, "How")], None, cx);
1644        });
1645
1646        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1647        let position = snapshot.anchor_before(language::Point::new(1, 3));
1648
1649        let prediction_task = zeta.update(cx, |zeta, cx| {
1650            zeta.request_prediction(&project, &buffer, position, cx)
1651        });
1652
1653        let (request, respond_tx) = req_rx.next().await.unwrap();
1654
1655        assert_eq!(request.events.len(), 1);
1656        assert_eq!(
1657            request.events[0],
1658            predict_edits_v3::Event::BufferChange {
1659                path: Some(PathBuf::from(path!("root/foo.md"))),
1660                old_path: None,
1661                diff: indoc! {"
1662                        @@ -1,3 +1,3 @@
1663                         Hello!
1664                        -
1665                        +How
1666                         Bye
1667                    "}
1668                .to_string(),
1669                predicted: false
1670            }
1671        );
1672
1673        respond_tx
1674            .send(predict_edits_v3::PredictEditsResponse {
1675                request_id: Uuid::new_v4(),
1676                edits: vec![predict_edits_v3::Edit {
1677                    path: Path::new(path!("root/foo.md")).into(),
1678                    range: Line(0)..Line(snapshot.max_point().row + 1),
1679                    content: "Hello!\nHow are you?\nBye".into(),
1680                }],
1681                debug_info: None,
1682            })
1683            .unwrap();
1684
1685        let prediction = prediction_task.await.unwrap().unwrap();
1686
1687        assert_eq!(prediction.edits.len(), 1);
1688        assert_eq!(
1689            prediction.edits[0].0.to_point(&snapshot).start,
1690            language::Point::new(1, 3)
1691        );
1692        assert_eq!(prediction.edits[0].1, " are you?");
1693    }
1694
1695    #[gpui::test]
1696    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1697        let (zeta, mut req_rx) = init_test(cx);
1698        let fs = FakeFs::new(cx.executor());
1699        fs.insert_tree(
1700            "/root",
1701            json!({
1702                "foo.md": "Hello!\nBye"
1703            }),
1704        )
1705        .await;
1706        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1707
1708        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1709        let diagnostic = lsp::Diagnostic {
1710            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1711            severity: Some(lsp::DiagnosticSeverity::ERROR),
1712            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1713            ..Default::default()
1714        };
1715
1716        project.update(cx, |project, cx| {
1717            project.lsp_store().update(cx, |lsp_store, cx| {
1718                // Create some diagnostics
1719                lsp_store
1720                    .update_diagnostics(
1721                        LanguageServerId(0),
1722                        lsp::PublishDiagnosticsParams {
1723                            uri: path_to_buffer_uri.clone(),
1724                            diagnostics: vec![diagnostic],
1725                            version: None,
1726                        },
1727                        None,
1728                        language::DiagnosticSourceKind::Pushed,
1729                        &[],
1730                        cx,
1731                    )
1732                    .unwrap();
1733            });
1734        });
1735
1736        let buffer = project
1737            .update(cx, |project, cx| {
1738                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1739                project.open_buffer(path, cx)
1740            })
1741            .await
1742            .unwrap();
1743
1744        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1745        let position = snapshot.anchor_before(language::Point::new(0, 0));
1746
1747        let _prediction_task = zeta.update(cx, |zeta, cx| {
1748            zeta.request_prediction(&project, &buffer, position, cx)
1749        });
1750
1751        let (request, _respond_tx) = req_rx.next().await.unwrap();
1752
1753        assert_eq!(request.diagnostic_groups.len(), 1);
1754        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1755            .unwrap();
1756        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1757        assert_eq!(
1758            value,
1759            json!({
1760                "entries": [{
1761                    "range": {
1762                        "start": 8,
1763                        "end": 10
1764                    },
1765                    "diagnostic": {
1766                        "source": null,
1767                        "code": null,
1768                        "code_description": null,
1769                        "severity": 1,
1770                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1771                        "markdown": null,
1772                        "group_id": 0,
1773                        "is_primary": true,
1774                        "is_disk_based": false,
1775                        "is_unnecessary": false,
1776                        "source_kind": "Pushed",
1777                        "data": null,
1778                        "underline": true
1779                    }
1780                }],
1781                "primary_ix": 0
1782            })
1783        );
1784    }
1785
1786    fn init_test(
1787        cx: &mut TestAppContext,
1788    ) -> (
1789        Entity<Zeta>,
1790        mpsc::UnboundedReceiver<(
1791            predict_edits_v3::PredictEditsRequest,
1792            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1793        )>,
1794    ) {
1795        cx.update(move |cx| {
1796            let settings_store = SettingsStore::test(cx);
1797            cx.set_global(settings_store);
1798            language::init(cx);
1799            Project::init_settings(cx);
1800
1801            let (req_tx, req_rx) = mpsc::unbounded();
1802
1803            let http_client = FakeHttpClient::create({
1804                move |req| {
1805                    let uri = req.uri().path().to_string();
1806                    let mut body = req.into_body();
1807                    let req_tx = req_tx.clone();
1808                    async move {
1809                        let resp = match uri.as_str() {
1810                            "/client/llm_tokens" => serde_json::to_string(&json!({
1811                                "token": "test"
1812                            }))
1813                            .unwrap(),
1814                            "/predict_edits/v3" => {
1815                                let mut buf = Vec::new();
1816                                body.read_to_end(&mut buf).await.ok();
1817                                let req = serde_json::from_slice(&buf).unwrap();
1818
1819                                let (res_tx, res_rx) = oneshot::channel();
1820                                req_tx.unbounded_send((req, res_tx)).unwrap();
1821                                serde_json::to_string(&res_rx.await?).unwrap()
1822                            }
1823                            _ => {
1824                                panic!("Unexpected path: {}", uri)
1825                            }
1826                        };
1827
1828                        Ok(Response::builder().body(resp.into()).unwrap())
1829                    }
1830                }
1831            });
1832
1833            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1834            client.cloud_client().set_credentials(1, "test".into());
1835
1836            language_model::init(client.clone(), cx);
1837
1838            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1839            let zeta = Zeta::global(&client, &user_store, cx);
1840
1841            (zeta, req_rx)
1842        })
1843    }
1844}