// zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   6use cloud_llm_client::{
   7    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   8    ZED_VERSION_HEADER_NAME,
   9};
  10use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  11use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  12use collections::HashMap;
  13use edit_prediction_context::{
  14    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  15    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  16    SyntaxIndex, SyntaxIndexState,
  17};
  18use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  19use futures::AsyncReadExt as _;
  20use futures::channel::{mpsc, oneshot};
  21use gpui::http_client::{AsyncBody, Method};
  22use gpui::{
  23    App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity,
  24    http_client, prelude::*,
  25};
  26use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
  27use language::{BufferSnapshot, OffsetRangeExt};
  28use language_model::{LlmApiToken, RefreshLlmTokenListener};
  29use lsp::DiagnosticSeverity;
  30use open_ai::FunctionDefinition;
  31use project::{Project, ProjectPath};
  32use release_channel::AppVersion;
  33use semver::Version;
  34use serde::de::DeserializeOwned;
  35use std::collections::{VecDeque, hash_map};
  36
  37use std::fmt::Write;
  38use std::ops::Range;
  39use std::path::Path;
  40use std::str::FromStr;
  41use std::sync::{Arc, LazyLock};
  42use std::time::{Duration, Instant};
  43use std::{env, mem};
  44use thiserror::Error;
  45use util::rel_path::RelPathBuf;
  46use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
  47use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  48
  49pub mod assemble_excerpts;
  50mod prediction;
  51mod provider;
  52pub mod retrieval_search;
  53mod sweep_ai;
  54pub mod udiff;
  55mod xml_edits;
  56
  57use crate::assemble_excerpts::assemble_excerpts;
  58pub use crate::prediction::EditPrediction;
  59pub use crate::prediction::EditPredictionId;
  60pub use provider::ZetaEditPredictionProvider;
  61
/// Maximum number of edit-history events to track when using the Sweep model.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
/// Maximum number of edit-history events to track when using the Zed cloud model.
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Consecutive edits whose rows are within this many lines of each other are
/// coalesced into a single history event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  66
  67pub struct SweepFeatureFlag;
  68
  69impl FeatureFlag for SweepFeatureFlag {
  70    const NAME: &str = "sweep-ai";
  71}
/// Default sizing of the excerpt of text captured around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context strategy: agentic (model-driven) retrieval.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults used when context is gathered from the local syntax index.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Overall default `ZetaOptions`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
 103
 104static USE_OLLAMA: LazyLock<bool> =
 105    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
 106static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 107    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
 108        "qwen3-coder:30b".to_string()
 109    } else {
 110        "yqvev8r3".to_string()
 111    })
 112});
 113static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
 114    match env::var("ZED_ZETA2_MODEL").as_deref() {
 115        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
 116        Ok(model) => model,
 117        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
 118        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
 119    }
 120    .to_string()
 121});
 122static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 123    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 124        if *USE_OLLAMA {
 125            Some("http://localhost:11434/v1/chat/completions".into())
 126        } else {
 127            None
 128        }
 129    })
 130});
 131
/// Feature flag gating the zeta2 edit-prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    /// Not auto-enabled for staff; must be toggled explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 141
/// Global wrapper storing the shared `Zeta` entity in the GPUI app context.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 146
/// Entry point for zeta2 edit predictions. A single instance is shared
/// globally (see `ZetaGlobal`) and tracks per-project prediction state.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate against the Zed LLM backend.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project entity's id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Presumably set when the backend reports this client version is too old
    // (see `MINIMUM_REQUIRED_VERSION_HEADER_NAME`) — confirm where it is written.
    update_required: bool,
    /// When set, debug events are emitted for every prediction request.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    /// Read from the `SWEEP_AI_TOKEN` env var at construction; `None` when unset.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}
 162
 163#[derive(PartialEq, Eq)]
 164pub enum ZetaEditPredictionModel {
 165    ZedCloud,
 166    Sweep,
 167}
 168
/// Tunable parameters for zeta2 prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered.
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    /// Buffer edits arriving closer together than this are grouped.
    pub buffer_change_grouping_interval: Duration,
}
 178
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Model-driven retrieval (search tool queries).
    Agentic(AgenticContextOptions),
    /// Retrieval from the local syntax index.
    Syntax(EditPredictionContextOptions),
}

#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    /// Sizing of the excerpt captured around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
}
 189
 190impl ContextMode {
 191    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 192        match self {
 193            ContextMode::Agentic(options) => &options.excerpt,
 194            ContextMode::Syntax(options) => &options.excerpt,
 195        }
 196    }
 197}
 198
/// Events emitted on the channel returned by `Zeta::debug_info`, tracing the
/// lifecycle of context retrieval and prediction requests.
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
 207
/// Payload for `ZetaDebugInfo::ContextRetrievalStarted`.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

/// Shared payload for retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for `ZetaDebugInfo::EditPredictionRequested`.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// Locally rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error string) and a time delta.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

/// Payload for `ZetaDebugInfo::SearchQueriesGenerated`.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 239
/// Prediction state tracked per registered project.
struct ZetaProject {
    /// Only populated when `ContextMode::Syntax` is in effect at registration.
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Recent buffer-change events, oldest first, capped at `EVENT_COUNT_MAX_*`.
    events: VecDeque<Event>,
    /// Recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    next_pending_prediction_id: usize,
    /// In-flight prediction requests; capacity bounds concurrency at two.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    last_prediction_refresh: Option<(EntityId, Instant)>,
    /// Retrieved context excerpts per buffer; `None` until context is gathered.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}
 255
/// The prediction currently offered to the user, plus what triggered it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by: PredictionRequestedBy,
    pub prediction: EditPrediction,
}
 261
impl CurrentEditPrediction {
    /// Decides whether `self` (a freshly computed prediction) should replace
    /// `old_prediction` as the project's current prediction.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        // If the new prediction no longer applies to the buffer's current
        // content, keep the old one.
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        // A prediction for a different buffer always replaces the old one.
        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        // If the old prediction no longer applies, replace it.
        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        let requested_by_buffer_id = self.requested_by.buffer_id();

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
            && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Only replace a single-edit prediction when the new edit extends
            // the old one in place: same range, old text is a prefix of new.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 300
 301#[derive(Debug, Clone)]
 302enum PredictionRequestedBy {
 303    DiagnosticsUpdate,
 304    Buffer(EntityId),
 305}
 306
 307impl PredictionRequestedBy {
 308    pub fn buffer_id(&self) -> Option<EntityId> {
 309        match self {
 310            PredictionRequestedBy::DiagnosticsUpdate => None,
 311            PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
 312        }
 313    }
 314}
 315
/// Handle to an in-flight prediction request.
// NOTE(review): presumably dropping `_task` cancels the request (gpui task
// semantics) — confirm.
struct PendingPrediction {
    id: usize,
    _task: Task<()>,
}
 320
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer directly.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets a different buffer; offer a jump affordance.
    Jump { prediction: &'a EditPrediction },
}
 327
/// State for a buffer registered for change reporting.
struct RegisteredBuffer {
    /// Snapshot as of the last processed change; diffed against new edits.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
 332
/// An entry in the per-project edit history sent with prediction requests.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// End position of the last edit between the two snapshots, if any.
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}
 342
impl Event {
    /// Converts this event into the wire format sent with prediction requests.
    /// Returns `None` when the event carries no information.
    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
        match self {
            Event::BufferChange {
                old_snapshot,
                new_snapshot,
                ..
            } => {
                let path = new_snapshot.file().map(|f| f.full_path(cx));

                // `old_path` is only `Some` when it differs from the current
                // path, i.e. the buffer was renamed or moved.
                let old_path = old_snapshot.file().and_then(|f| {
                    let old_path = f.full_path(cx);
                    if Some(&old_path) != path.as_ref() {
                        Some(old_path)
                    } else {
                        None
                    }
                });

                // TODO [zeta2] move to bg?
                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());

                // NOTE(review): since `old_path` is `Some` only when it differs
                // from `path`, this equality can only hold when both are `None`;
                // confirm whether `old_path.is_none() && diff.is_empty()` was
                // the intended "nothing to report" condition.
                if path == old_path && diff.is_empty() {
                    None
                } else {
                    Some(predict_edits_v3::Event::BufferChange {
                        old_path,
                        path,
                        diff,
                        //todo: Actually detect if this edit was predicted or not
                        predicted: false,
                    })
                }
            }
        }
    }

    /// The project path of the buffer this event refers to, if it has a file.
    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
        match self {
            Event::BufferChange { new_snapshot, .. } => new_snapshot
                .file()
                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
        }
    }
}
 388
 389impl Zeta {
 390    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 391        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 392    }
 393
 394    pub fn global(
 395        client: &Arc<Client>,
 396        user_store: &Entity<UserStore>,
 397        cx: &mut App,
 398    ) -> Entity<Self> {
 399        cx.try_global::<ZetaGlobal>()
 400            .map(|global| global.0.clone())
 401            .unwrap_or_else(|| {
 402                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 403                cx.set_global(ZetaGlobal(zeta.clone()));
 404                zeta
 405            })
 406    }
 407
    /// Creates a `Zeta` instance. Prefer `Zeta::global` to share one instance.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever a refresh is requested globally.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            // A missing token is logged (via `log_err`) but does not fail
            // construction; Sweep simply remains unavailable.
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }
 440
    /// Selects which backend (`ZedCloud` or `Sweep`) serves predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }
 444
    /// Whether a `SWEEP_AI_TOKEN` was found in the environment at construction.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }
 448
    /// Installs a cache used by evals (only with the `eval-support` feature).
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 453
 454    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 455        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 456        self.debug_tx = Some(debug_watch_tx);
 457        debug_watch_rx
 458    }
 459
    /// The currently active prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 463
    /// Replaces the prediction options.
    // NOTE(review): existing projects' `syntax_index` is only created at
    // registration time, so changing the context mode here does not
    // (re)build indexes for already-registered projects — confirm intended.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 467
 468    pub fn clear_history(&mut self) {
 469        for zeta_project in self.projects.values_mut() {
 470            zeta_project.events.clear();
 471        }
 472    }
 473
 474    pub fn history_for_project(
 475        &self,
 476        project: &Entity<Project>,
 477    ) -> impl DoubleEndedIterator<Item = &Event> {
 478        self.projects
 479            .get(&project.entity_id())
 480            .map(|project| project.events.iter())
 481            .into_iter()
 482            .flatten()
 483    }
 484
 485    pub fn context_for_project(
 486        &self,
 487        project: &Entity<Project>,
 488    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 489        self.projects
 490            .get(&project.entity_id())
 491            .and_then(|project| {
 492                Some(
 493                    project
 494                        .context
 495                        .as_ref()?
 496                        .iter()
 497                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 498                )
 499            })
 500            .into_iter()
 501            .flatten()
 502    }
 503
 504    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 505        if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
 506            self.user_store.read(cx).edit_prediction_usage()
 507        } else {
 508            None
 509        }
 510    }
 511
    /// Starts tracking `project`, initializing its per-project state.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
 515
    /// Registers `buffer` within `project` so its edits are recorded in the
    /// prediction event history.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 525
    /// Returns the state for `project`, creating it (and subscribing to its
    /// project events) on first use.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                // The syntax index is only needed for syntax-based context mode.
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                pending_predictions: ArrayVec::new(),
                next_pending_prediction_id: 0,
                last_prediction_refresh: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, Self::handle_project_event),
            })
    }
 555
    /// Reacts to project events: maintains the recently-active path list and
    /// refreshes predictions when diagnostics change.
    fn handle_project_event(
        &mut self,
        project: Entity<Project>,
        event: &project::Event,
        cx: &mut Context<Self>,
    ) {
        // TODO [zeta2] init with recent paths
        match event {
            project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
                let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
                    return;
                };
                let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                if let Some(path) = path {
                    // Move the path to the front, de-duplicating any previous
                    // occurrence.
                    // NOTE(review): `recent_paths` is never truncated here —
                    // confirm it is bounded elsewhere.
                    if let Some(ix) = zeta_project
                        .recent_paths
                        .iter()
                        .position(|probe| probe == &path)
                    {
                        zeta_project.recent_paths.remove(ix);
                    }
                    zeta_project.recent_paths.push_front(path);
                }
            }
            project::Event::DiagnosticsUpdated { .. } => {
                self.refresh_prediction_from_diagnostics(project, cx);
            }
            _ => (),
        }
    }
 586
    /// Ensures `buffer` is registered in `zeta_project`: subscribes to its
    /// edits and removes the registration when the buffer entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Record edits into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 624
 625    fn report_changes_for_buffer(
 626        &mut self,
 627        buffer: &Entity<Buffer>,
 628        project: &Entity<Project>,
 629        cx: &mut Context<Self>,
 630    ) {
 631        let event_count_max = match self.edit_prediction_model {
 632            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
 633            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
 634        };
 635
 636        let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
 637        let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
 638
 639        let new_snapshot = buffer.read(cx).snapshot();
 640        if new_snapshot.version == registered_buffer.snapshot.version {
 641            return;
 642        }
 643
 644        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 645        let end_edit_anchor = new_snapshot
 646            .anchored_edits_since::<Point>(&old_snapshot.version)
 647            .last()
 648            .map(|(_, range)| range.end);
 649        let events = &mut sweep_ai_project.events;
 650
 651        if let Some(Event::BufferChange {
 652            new_snapshot: last_new_snapshot,
 653            end_edit_anchor: last_end_edit_anchor,
 654            ..
 655        }) = events.back_mut()
 656        {
 657            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
 658                == last_new_snapshot.remote_id()
 659                && old_snapshot.version == last_new_snapshot.version;
 660
 661            let should_coalesce = is_next_snapshot_of_same_buffer
 662                && end_edit_anchor
 663                    .as_ref()
 664                    .zip(last_end_edit_anchor.as_ref())
 665                    .is_some_and(|(a, b)| {
 666                        let a = a.to_point(&new_snapshot);
 667                        let b = b.to_point(&new_snapshot);
 668                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
 669                    });
 670
 671            if should_coalesce {
 672                *last_end_edit_anchor = end_edit_anchor;
 673                *last_new_snapshot = new_snapshot;
 674                return;
 675            }
 676        }
 677
 678        if events.len() >= event_count_max {
 679            events.pop_front();
 680        }
 681
 682        events.push_back(Event::BufferChange {
 683            old_snapshot,
 684            new_snapshot,
 685            end_edit_anchor,
 686            timestamp: Instant::now(),
 687        });
 688    }
 689
 690    fn current_prediction_for_buffer(
 691        &self,
 692        buffer: &Entity<Buffer>,
 693        project: &Entity<Project>,
 694        cx: &App,
 695    ) -> Option<BufferEditPrediction<'_>> {
 696        let project_state = self.projects.get(&project.entity_id())?;
 697
 698        let CurrentEditPrediction {
 699            requested_by,
 700            prediction,
 701        } = project_state.current_prediction.as_ref()?;
 702
 703        if prediction.targets_buffer(buffer.read(cx)) {
 704            Some(BufferEditPrediction::Local { prediction })
 705        } else {
 706            let show_jump = match requested_by {
 707                PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
 708                    requested_by_buffer_id == &buffer.entity_id()
 709                }
 710                PredictionRequestedBy::DiagnosticsUpdate => true,
 711            };
 712
 713            if show_jump {
 714                Some(BufferEditPrediction::Jump { prediction })
 715            } else {
 716                None
 717            }
 718        }
 719    }
 720
    /// Reports acceptance of the current prediction to the Zed backend and
    /// clears the prediction state. No-op for non-cloud backends.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        // Only the Zed cloud backend tracks accepted predictions.
        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();
        project_state.pending_predictions.clear();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (development).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 770
 771    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 772        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 773            project_state.current_prediction.take();
 774            project_state.pending_predictions.clear();
 775        };
 776    }
 777
 778    fn is_refreshing(&self, project: &Entity<Project>) -> bool {
 779        self.projects
 780            .get(&project.entity_id())
 781            .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
 782    }
 783
 784    pub fn refresh_prediction_from_buffer(
 785        &mut self,
 786        project: Entity<Project>,
 787        buffer: Entity<Buffer>,
 788        position: language::Anchor,
 789        cx: &mut Context<Self>,
 790    ) {
 791        self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
 792            let Some(request_task) = this
 793                .update(cx, |this, cx| {
 794                    this.request_prediction(&project, &buffer, position, cx)
 795                })
 796                .log_err()
 797            else {
 798                return Task::ready(anyhow::Ok(()));
 799            };
 800
 801            let project = project.clone();
 802            cx.spawn(async move |cx| {
 803                if let Some(prediction) = request_task.await? {
 804                    this.update(cx, |this, cx| {
 805                        let project_state = this
 806                            .projects
 807                            .get_mut(&project.entity_id())
 808                            .context("Project not found")?;
 809
 810                        let new_prediction = CurrentEditPrediction {
 811                            requested_by: PredictionRequestedBy::Buffer(buffer.entity_id()),
 812                            prediction: prediction,
 813                        };
 814
 815                        if project_state
 816                            .current_prediction
 817                            .as_ref()
 818                            .is_none_or(|old_prediction| {
 819                                new_prediction.should_replace_prediction(&old_prediction, cx)
 820                            })
 821                        {
 822                            project_state.current_prediction = Some(new_prediction);
 823                            cx.notify();
 824                        }
 825                        anyhow::Ok(())
 826                    })??;
 827                }
 828                Ok(())
 829            })
 830        })
 831    }
 832
    /// Requests an edit prediction driven by project diagnostics rather than
    /// a cursor position.
    ///
    /// Looks for a diagnostic near the active buffer (falling back to other
    /// buffers with diagnostics — see `next_diagnostic_location`), requests a
    /// prediction at that location, and stores it unless a buffer-driven
    /// prediction arrived in the meantime.
    pub fn refresh_prediction_from_diagnostics(
        &mut self,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Prefer predictions from buffer
        if zeta_project.current_prediction.is_some() {
            return;
        };

        self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
            // Resolve the project's active entry to an open buffer; without
            // one there is nothing to scan for diagnostics.
            let Some(open_buffer_task) = project
                .update(cx, |project, cx| {
                    project
                        .active_entry()
                        .and_then(|entry| project.path_for_entry(entry, cx))
                        .map(|path| project.open_buffer(path, cx))
                })
                .log_err()
                .flatten()
            else {
                return Task::ready(anyhow::Ok(()));
            };

            cx.spawn(async move |cx| {
                let active_buffer = open_buffer_task.await?;
                let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

                // Default (empty) search range and cursor: there is no prior
                // request context, so any diagnostic is eligible.
                let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
                    active_buffer,
                    &snapshot,
                    Default::default(),
                    Default::default(),
                    &project,
                    cx,
                )
                .await?
                else {
                    return anyhow::Ok(());
                };

                let Some(prediction) = this
                    .update(cx, |this, cx| {
                        this.request_prediction(&project, &jump_buffer, jump_position, cx)
                    })?
                    .await?
                else {
                    return anyhow::Ok(());
                };

                this.update(cx, |this, cx| {
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        // Only store the result if no prediction appeared
                        // while this request was in flight; buffer-driven
                        // predictions take precedence.
                        zeta_project.current_prediction.get_or_insert_with(|| {
                            cx.notify();
                            CurrentEditPrediction {
                                requested_by: PredictionRequestedBy::DiagnosticsUpdate,
                                prediction,
                            }
                        });
                    }
                })?;

                anyhow::Ok(())
            })
        });
    }
 903
    /// Minimum spacing between consecutive prediction refreshes for the same
    /// entity (enforced in `queue_prediction_refresh`).
    #[cfg(not(test))]
    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
    // Disabled in tests so they don't have to wait out the timer.
    #[cfg(test)]
    pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
 908
 909    fn queue_prediction_refresh(
 910        &mut self,
 911        project: Entity<Project>,
 912        throttle_entity: EntityId,
 913        cx: &mut Context<Self>,
 914        do_refresh: impl FnOnce(WeakEntity<Self>, &mut AsyncApp) -> Task<Result<()>> + 'static,
 915    ) {
 916        let zeta_project = self.get_or_init_zeta_project(&project, cx);
 917        let pending_prediction_id = zeta_project.next_pending_prediction_id;
 918        zeta_project.next_pending_prediction_id += 1;
 919        let last_request = zeta_project.last_prediction_refresh;
 920
 921        // TODO report cancelled requests like in zeta1
 922        let task = cx.spawn(async move |this, cx| {
 923            if let Some((last_entity, last_timestamp)) = last_request
 924                && throttle_entity == last_entity
 925                && let Some(timeout) =
 926                    (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
 927            {
 928                cx.background_executor().timer(timeout).await;
 929            }
 930
 931            do_refresh(this.clone(), cx).await.log_err();
 932
 933            this.update(cx, |this, cx| {
 934                let zeta_project = this.get_or_init_zeta_project(&project, cx);
 935
 936                if zeta_project.pending_predictions[0].id == pending_prediction_id {
 937                    zeta_project.pending_predictions.remove(0);
 938                } else {
 939                    zeta_project.pending_predictions.clear();
 940                }
 941
 942                cx.notify();
 943            })
 944            .ok();
 945        });
 946
 947        if zeta_project.pending_predictions.len() <= 1 {
 948            zeta_project.pending_predictions.push(PendingPrediction {
 949                id: pending_prediction_id,
 950                _task: task,
 951            });
 952        } else if zeta_project.pending_predictions.len() == 2 {
 953            zeta_project.pending_predictions.pop();
 954            zeta_project.pending_predictions.push(PendingPrediction {
 955                id: pending_prediction_id,
 956                _task: task,
 957            });
 958        }
 959    }
 960
 961    pub fn request_prediction(
 962        &mut self,
 963        project: &Entity<Project>,
 964        active_buffer: &Entity<Buffer>,
 965        position: language::Anchor,
 966        cx: &mut Context<Self>,
 967    ) -> Task<Result<Option<EditPrediction>>> {
 968        match self.edit_prediction_model {
 969            ZetaEditPredictionModel::ZedCloud => {
 970                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
 971            }
 972            ZetaEditPredictionModel::Sweep => {
 973                self.request_prediction_with_sweep(project, active_buffer, position, true, cx)
 974            }
 975        }
 976    }
 977
 978    fn request_prediction_with_sweep(
 979        &mut self,
 980        project: &Entity<Project>,
 981        active_buffer: &Entity<Buffer>,
 982        position: language::Anchor,
 983        allow_jump: bool,
 984        cx: &mut Context<Self>,
 985    ) -> Task<Result<Option<EditPrediction>>> {
 986        let snapshot = active_buffer.read(cx).snapshot();
 987        let debug_info = self.sweep_ai_debug_info.clone();
 988        let Some(api_token) = self.sweep_api_token.clone() else {
 989            return Task::ready(Ok(None));
 990        };
 991        let full_path: Arc<Path> = snapshot
 992            .file()
 993            .map(|file| file.full_path(cx))
 994            .unwrap_or_else(|| "untitled".into())
 995            .into();
 996
 997        let project_file = project::File::from_dyn(snapshot.file());
 998        let repo_name = project_file
 999            .map(|file| file.worktree.read(cx).root_name_str())
1000            .unwrap_or("untitled")
1001            .into();
1002        let offset = position.to_offset(&snapshot);
1003
1004        let project_state = self.get_or_init_zeta_project(project, cx);
1005        let events = project_state.events.clone();
1006        let has_events = !events.is_empty();
1007        let recent_buffers = project_state.recent_paths.iter().cloned();
1008        let http_client = cx.http_client();
1009
1010        let recent_buffer_snapshots = recent_buffers
1011            .filter_map(|project_path| {
1012                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
1013                if active_buffer == &buffer {
1014                    None
1015                } else {
1016                    Some(buffer.read(cx).snapshot())
1017                }
1018            })
1019            .take(3)
1020            .collect::<Vec<_>>();
1021
1022        const DIAGNOSTIC_LINES_RANGE: u32 = 20;
1023
1024        let cursor_point = position.to_point(&snapshot);
1025        let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
1026        let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
1027        let diagnostic_search_range =
1028            Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
1029
1030        let result = cx.background_spawn({
1031            let snapshot = snapshot.clone();
1032            let diagnostic_search_range = diagnostic_search_range.clone();
1033            async move {
1034                let text = snapshot.text();
1035
1036                let mut recent_changes = String::new();
1037                for event in events {
1038                    sweep_ai::write_event(event, &mut recent_changes).unwrap();
1039                }
1040
1041                let mut file_chunks = recent_buffer_snapshots
1042                    .into_iter()
1043                    .map(|snapshot| {
1044                        let end_point = Point::new(30, 0).min(snapshot.max_point());
1045                        sweep_ai::FileChunk {
1046                            content: snapshot.text_for_range(Point::zero()..end_point).collect(),
1047                            file_path: snapshot
1048                                .file()
1049                                .map(|f| f.path().as_unix_str())
1050                                .unwrap_or("untitled")
1051                                .to_string(),
1052                            start_line: 0,
1053                            end_line: end_point.row as usize,
1054                            timestamp: snapshot.file().and_then(|file| {
1055                                Some(
1056                                    file.disk_state()
1057                                        .mtime()?
1058                                        .to_seconds_and_nanos_for_persistence()?
1059                                        .0,
1060                                )
1061                            }),
1062                        }
1063                    })
1064                    .collect::<Vec<_>>();
1065
1066                let diagnostic_entries =
1067                    snapshot.diagnostics_in_range(diagnostic_search_range, false);
1068                let mut diagnostic_content = String::new();
1069                let mut diagnostic_count = 0;
1070
1071                for entry in diagnostic_entries {
1072                    let start_point: Point = entry.range.start;
1073
1074                    let severity = match entry.diagnostic.severity {
1075                        DiagnosticSeverity::ERROR => "error",
1076                        DiagnosticSeverity::WARNING => "warning",
1077                        DiagnosticSeverity::INFORMATION => "info",
1078                        DiagnosticSeverity::HINT => "hint",
1079                        _ => continue,
1080                    };
1081
1082                    diagnostic_count += 1;
1083
1084                    writeln!(
1085                        &mut diagnostic_content,
1086                        "{} at line {}: {}",
1087                        severity,
1088                        start_point.row + 1,
1089                        entry.diagnostic.message
1090                    )?;
1091                }
1092
1093                if !diagnostic_content.is_empty() {
1094                    file_chunks.push(sweep_ai::FileChunk {
1095                        file_path: format!("Diagnostics for {}", full_path.display()),
1096                        start_line: 0,
1097                        end_line: diagnostic_count,
1098                        content: diagnostic_content,
1099                        timestamp: None,
1100                    });
1101                }
1102
1103                let request_body = sweep_ai::AutocompleteRequest {
1104                    debug_info,
1105                    repo_name,
1106                    file_path: full_path.clone(),
1107                    file_contents: text.clone(),
1108                    original_file_contents: text,
1109                    cursor_position: offset,
1110                    recent_changes: recent_changes.clone(),
1111                    changes_above_cursor: true,
1112                    multiple_suggestions: false,
1113                    branch: None,
1114                    file_chunks,
1115                    retrieval_chunks: vec![],
1116                    recent_user_actions: vec![],
1117                    // TODO
1118                    privacy_mode_enabled: false,
1119                };
1120
1121                let mut buf: Vec<u8> = Vec::new();
1122                let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
1123                serde_json::to_writer(writer, &request_body)?;
1124                let body: AsyncBody = buf.into();
1125
1126                const SWEEP_API_URL: &str =
1127                    "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
1128
1129                let request = http_client::Request::builder()
1130                    .uri(SWEEP_API_URL)
1131                    .header("Content-Type", "application/json")
1132                    .header("Authorization", format!("Bearer {}", api_token))
1133                    .header("Connection", "keep-alive")
1134                    .header("Content-Encoding", "br")
1135                    .method(Method::POST)
1136                    .body(body)?;
1137
1138                let mut response = http_client.send(request).await?;
1139
1140                let mut body: Vec<u8> = Vec::new();
1141                response.body_mut().read_to_end(&mut body).await?;
1142
1143                if !response.status().is_success() {
1144                    anyhow::bail!(
1145                        "Request failed with status: {:?}\nBody: {}",
1146                        response.status(),
1147                        String::from_utf8_lossy(&body),
1148                    );
1149                };
1150
1151                let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
1152
1153                let old_text = snapshot
1154                    .text_for_range(response.start_index..response.end_index)
1155                    .collect::<String>();
1156                let edits = language::text_diff(&old_text, &response.completion)
1157                    .into_iter()
1158                    .map(|(range, text)| {
1159                        (
1160                            snapshot.anchor_after(response.start_index + range.start)
1161                                ..snapshot.anchor_before(response.start_index + range.end),
1162                            text,
1163                        )
1164                    })
1165                    .collect::<Vec<_>>();
1166
1167                anyhow::Ok((response.autocomplete_id, edits, snapshot))
1168            }
1169        });
1170
1171        let buffer = active_buffer.clone();
1172        let project = project.clone();
1173        let active_buffer = active_buffer.clone();
1174
1175        cx.spawn(async move |this, cx| {
1176            let (id, edits, old_snapshot) = result.await?;
1177
1178            if edits.is_empty() {
1179                if has_events
1180                    && allow_jump
1181                    && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
1182                        active_buffer,
1183                        &snapshot,
1184                        diagnostic_search_range,
1185                        cursor_point,
1186                        &project,
1187                        cx,
1188                    )
1189                    .await?
1190                {
1191                    return this
1192                        .update(cx, |this, cx| {
1193                            this.request_prediction_with_sweep(
1194                                &project,
1195                                &jump_buffer,
1196                                jump_position,
1197                                false,
1198                                cx,
1199                            )
1200                        })?
1201                        .await;
1202                }
1203
1204                return anyhow::Ok(None);
1205            }
1206
1207            let Some((edits, new_snapshot, preview_task)) =
1208                buffer.read_with(cx, |buffer, cx| {
1209                    let new_snapshot = buffer.snapshot();
1210
1211                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
1212                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
1213                            .into();
1214                    let preview_task = buffer.preview_edits(edits.clone(), cx);
1215
1216                    Some((edits, new_snapshot, preview_task))
1217                })?
1218            else {
1219                return anyhow::Ok(None);
1220            };
1221
1222            let prediction = EditPrediction {
1223                id: EditPredictionId(id.into()),
1224                edits,
1225                snapshot: new_snapshot,
1226                edit_preview: preview_task.await,
1227                buffer,
1228            };
1229
1230            anyhow::Ok(Some(prediction))
1231        })
1232    }
1233
    /// Finds a diagnostic location to "jump" to for a follow-up prediction.
    ///
    /// First searches the active buffer for the diagnostic closest to the
    /// cursor that lies outside `active_buffer_diagnostic_search_range`
    /// (diagnostics inside that range were already covered by the previous
    /// request). Failing that, opens the project buffer with diagnostics
    /// whose path shares the most leading components with the active buffer's
    /// path and targets its most severe diagnostic. Returns `None` when no
    /// candidate exists anywhere.
    async fn next_diagnostic_location(
        active_buffer: Entity<Buffer>,
        active_buffer_snapshot: &BufferSnapshot,
        active_buffer_diagnostic_search_range: Range<Point>,
        active_buffer_cursor_point: Point,
        project: &Entity<Project>,
        cx: &mut AsyncApp,
    ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
        // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
        let mut jump_location = active_buffer_snapshot
            .diagnostic_groups(None)
            .into_iter()
            .filter_map(|(_, group)| {
                let range = &group.entries[group.primary_ix]
                    .range
                    .to_point(&active_buffer_snapshot);
                if range.overlaps(&active_buffer_diagnostic_search_range) {
                    None
                } else {
                    Some(range.start)
                }
            })
            // Closest by line distance from the cursor.
            .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
            .map(|position| {
                (
                    active_buffer.clone(),
                    active_buffer_snapshot.anchor_before(position),
                )
            });

        if jump_location.is_none() {
            // `None` if the active buffer has no backing file (e.g. untitled).
            let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
                let file = buffer.file()?;

                Some(ProjectPath {
                    worktree_id: file.worktree_id(cx),
                    path: file.path().clone(),
                })
            })?;

            let buffer_task = project.update(cx, |project, cx| {
                let (path, _, _) = project
                    .diagnostic_summaries(false, cx)
                    .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
                    .max_by_key(|(path, _, _)| {
                        // find the buffer with errors that shares most parent directories
                        path.path
                            .components()
                            .zip(
                                active_buffer_path
                                    .as_ref()
                                    .map(|p| p.path.components())
                                    .unwrap_or_default(),
                            )
                            .take_while(|(a, b)| a == b)
                            .count()
                    })?;

                Some(project.open_buffer(path, cx))
            })?;

            if let Some(buffer_task) = buffer_task {
                let closest_buffer = buffer_task.await?;

                jump_location = closest_buffer
                    .read_with(cx, |buffer, _cx| {
                        buffer
                            .buffer_diagnostics(None)
                            .into_iter()
                            // LSP severity values are ordered with lower ==
                            // more severe (Error is 1), so `min` selects the
                            // most severe diagnostic.
                            .min_by_key(|entry| entry.diagnostic.severity)
                            .map(|entry| entry.range.start)
                    })?
                    .map(|position| (closest_buffer, position));
            }
        }

        anyhow::Ok(jump_location)
    }
1312
1313    fn request_prediction_with_zed_cloud(
1314        &mut self,
1315        project: &Entity<Project>,
1316        active_buffer: &Entity<Buffer>,
1317        position: language::Anchor,
1318        cx: &mut Context<Self>,
1319    ) -> Task<Result<Option<EditPrediction>>> {
1320        let project_state = self.projects.get(&project.entity_id());
1321
1322        let index_state = project_state.and_then(|state| {
1323            state
1324                .syntax_index
1325                .as_ref()
1326                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
1327        });
1328        let options = self.options.clone();
1329        let active_snapshot = active_buffer.read(cx).snapshot();
1330        let Some(excerpt_path) = active_snapshot
1331            .file()
1332            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
1333        else {
1334            return Task::ready(Err(anyhow!("No file path for excerpt")));
1335        };
1336        let client = self.client.clone();
1337        let llm_token = self.llm_token.clone();
1338        let app_version = AppVersion::global(cx);
1339        let worktree_snapshots = project
1340            .read(cx)
1341            .worktrees(cx)
1342            .map(|worktree| worktree.read(cx).snapshot())
1343            .collect::<Vec<_>>();
1344        let debug_tx = self.debug_tx.clone();
1345
1346        let events = project_state
1347            .map(|state| {
1348                state
1349                    .events
1350                    .iter()
1351                    .filter_map(|event| event.to_request_event(cx))
1352                    .collect::<Vec<_>>()
1353            })
1354            .unwrap_or_default();
1355
1356        let diagnostics = active_snapshot.diagnostic_sets().clone();
1357
1358        let parent_abs_path =
1359            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
1360                let mut path = f.worktree.read(cx).absolutize(&f.path);
1361                if path.pop() { Some(path) } else { None }
1362            });
1363
1364        // TODO data collection
1365        let can_collect_data = cx.is_staff();
1366
1367        let empty_context_files = HashMap::default();
1368        let context_files = project_state
1369            .and_then(|project_state| project_state.context.as_ref())
1370            .unwrap_or(&empty_context_files);
1371
1372        #[cfg(feature = "eval-support")]
1373        let parsed_fut = futures::future::join_all(
1374            context_files
1375                .keys()
1376                .map(|buffer| buffer.read(cx).parsing_idle()),
1377        );
1378
1379        let mut included_files = context_files
1380            .iter()
1381            .filter_map(|(buffer_entity, ranges)| {
1382                let buffer = buffer_entity.read(cx);
1383                Some((
1384                    buffer_entity.clone(),
1385                    buffer.snapshot(),
1386                    buffer.file()?.full_path(cx).into(),
1387                    ranges.clone(),
1388                ))
1389            })
1390            .collect::<Vec<_>>();
1391
1392        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
1393            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
1394        });
1395
1396        #[cfg(feature = "eval-support")]
1397        let eval_cache = self.eval_cache.clone();
1398
1399        let request_task = cx.background_spawn({
1400            let active_buffer = active_buffer.clone();
1401            async move {
1402                #[cfg(feature = "eval-support")]
1403                parsed_fut.await;
1404
1405                let index_state = if let Some(index_state) = index_state {
1406                    Some(index_state.lock_owned().await)
1407                } else {
1408                    None
1409                };
1410
1411                let cursor_offset = position.to_offset(&active_snapshot);
1412                let cursor_point = cursor_offset.to_point(&active_snapshot);
1413
1414                let before_retrieval = chrono::Utc::now();
1415
1416                let (diagnostic_groups, diagnostic_groups_truncated) =
1417                    Self::gather_nearby_diagnostics(
1418                        cursor_offset,
1419                        &diagnostics,
1420                        &active_snapshot,
1421                        options.max_diagnostic_bytes,
1422                    );
1423
1424                let cloud_request = match options.context {
1425                    ContextMode::Agentic(context_options) => {
1426                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
1427                            cursor_point,
1428                            &active_snapshot,
1429                            &context_options.excerpt,
1430                            index_state.as_deref(),
1431                        ) else {
1432                            return Ok((None, None));
1433                        };
1434
1435                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
1436                            ..active_snapshot.anchor_before(excerpt.range.end);
1437
1438                        if let Some(buffer_ix) =
1439                            included_files.iter().position(|(_, snapshot, _, _)| {
1440                                snapshot.remote_id() == active_snapshot.remote_id()
1441                            })
1442                        {
1443                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
1444                            ranges.push(excerpt_anchor_range);
1445                            retrieval_search::merge_anchor_ranges(ranges, buffer);
1446                            let last_ix = included_files.len() - 1;
1447                            included_files.swap(buffer_ix, last_ix);
1448                        } else {
1449                            included_files.push((
1450                                active_buffer.clone(),
1451                                active_snapshot.clone(),
1452                                excerpt_path.clone(),
1453                                vec![excerpt_anchor_range],
1454                            ));
1455                        }
1456
1457                        let included_files = included_files
1458                            .iter()
1459                            .map(|(_, snapshot, path, ranges)| {
1460                                let ranges = ranges
1461                                    .iter()
1462                                    .map(|range| {
1463                                        let point_range = range.to_point(&snapshot);
1464                                        Line(point_range.start.row)..Line(point_range.end.row)
1465                                    })
1466                                    .collect::<Vec<_>>();
1467                                let excerpts = assemble_excerpts(&snapshot, ranges);
1468                                predict_edits_v3::IncludedFile {
1469                                    path: path.clone(),
1470                                    max_row: Line(snapshot.max_point().row),
1471                                    excerpts,
1472                                }
1473                            })
1474                            .collect::<Vec<_>>();
1475
1476                        predict_edits_v3::PredictEditsRequest {
1477                            excerpt_path,
1478                            excerpt: String::new(),
1479                            excerpt_line_range: Line(0)..Line(0),
1480                            excerpt_range: 0..0,
1481                            cursor_point: predict_edits_v3::Point {
1482                                line: predict_edits_v3::Line(cursor_point.row),
1483                                column: cursor_point.column,
1484                            },
1485                            included_files,
1486                            referenced_declarations: vec![],
1487                            events,
1488                            can_collect_data,
1489                            diagnostic_groups,
1490                            diagnostic_groups_truncated,
1491                            debug_info: debug_tx.is_some(),
1492                            prompt_max_bytes: Some(options.max_prompt_bytes),
1493                            prompt_format: options.prompt_format,
1494                            // TODO [zeta2]
1495                            signatures: vec![],
1496                            excerpt_parent: None,
1497                            git_info: None,
1498                        }
1499                    }
1500                    ContextMode::Syntax(context_options) => {
1501                        let Some(context) = EditPredictionContext::gather_context(
1502                            cursor_point,
1503                            &active_snapshot,
1504                            parent_abs_path.as_deref(),
1505                            &context_options,
1506                            index_state.as_deref(),
1507                        ) else {
1508                            return Ok((None, None));
1509                        };
1510
1511                        make_syntax_context_cloud_request(
1512                            excerpt_path,
1513                            context,
1514                            events,
1515                            can_collect_data,
1516                            diagnostic_groups,
1517                            diagnostic_groups_truncated,
1518                            None,
1519                            debug_tx.is_some(),
1520                            &worktree_snapshots,
1521                            index_state.as_deref(),
1522                            Some(options.max_prompt_bytes),
1523                            options.prompt_format,
1524                        )
1525                    }
1526                };
1527
1528                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
1529
1530                let retrieval_time = chrono::Utc::now() - before_retrieval;
1531
1532                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
1533                    let (response_tx, response_rx) = oneshot::channel();
1534
1535                    debug_tx
1536                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
1537                            ZetaEditPredictionDebugInfo {
1538                                request: cloud_request.clone(),
1539                                retrieval_time,
1540                                buffer: active_buffer.downgrade(),
1541                                local_prompt: match prompt_result.as_ref() {
1542                                    Ok((prompt, _)) => Ok(prompt.clone()),
1543                                    Err(err) => Err(err.to_string()),
1544                                },
1545                                position,
1546                                response_rx,
1547                            },
1548                        ))
1549                        .ok();
1550                    Some(response_tx)
1551                } else {
1552                    None
1553                };
1554
1555                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
1556                    if let Some(debug_response_tx) = debug_response_tx {
1557                        debug_response_tx
1558                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
1559                            .ok();
1560                    }
1561                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
1562                }
1563
1564                let (prompt, _) = prompt_result?;
1565                let generation_params =
1566                    cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
1567                let request = open_ai::Request {
1568                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
1569                    messages: vec![open_ai::RequestMessage::User {
1570                        content: open_ai::MessageContent::Plain(prompt),
1571                    }],
1572                    stream: false,
1573                    max_completion_tokens: None,
1574                    stop: generation_params.stop.unwrap_or_default(),
1575                    temperature: generation_params.temperature.unwrap_or(0.7),
1576                    tool_choice: None,
1577                    parallel_tool_calls: None,
1578                    tools: vec![],
1579                    prompt_cache_key: None,
1580                    reasoning_effort: None,
1581                };
1582
1583                log::trace!("Sending edit prediction request");
1584
1585                let before_request = chrono::Utc::now();
1586                let response = Self::send_raw_llm_request(
1587                    request,
1588                    client,
1589                    llm_token,
1590                    app_version,
1591                    #[cfg(feature = "eval-support")]
1592                    eval_cache,
1593                    #[cfg(feature = "eval-support")]
1594                    EvalCacheEntryKind::Prediction,
1595                )
1596                .await;
1597                let request_time = chrono::Utc::now() - before_request;
1598
1599                log::trace!("Got edit prediction response");
1600
1601                if let Some(debug_response_tx) = debug_response_tx {
1602                    debug_response_tx
1603                        .send((
1604                            response
1605                                .as_ref()
1606                                .map_err(|err| err.to_string())
1607                                .map(|response| response.0.clone()),
1608                            request_time,
1609                        ))
1610                        .ok();
1611                }
1612
1613                let (res, usage) = response?;
1614                let request_id = EditPredictionId(res.id.clone().into());
1615                let Some(mut output_text) = text_from_response(res) else {
1616                    return Ok((None, usage));
1617                };
1618
1619                if output_text.contains(CURSOR_MARKER) {
1620                    log::trace!("Stripping out {CURSOR_MARKER} from response");
1621                    output_text = output_text.replace(CURSOR_MARKER, "");
1622                }
1623
1624                let get_buffer_from_context = |path: &Path| {
1625                    included_files
1626                        .iter()
1627                        .find_map(|(_, buffer, probe_path, ranges)| {
1628                            if probe_path.as_ref() == path {
1629                                Some((buffer, ranges.as_slice()))
1630                            } else {
1631                                None
1632                            }
1633                        })
1634                };
1635
1636                let (edited_buffer_snapshot, edits) = match options.prompt_format {
1637                    PromptFormat::NumLinesUniDiff => {
1638                        // TODO: Implement parsing of multi-file diffs
1639                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1640                    }
1641                    PromptFormat::Minimal
1642                    | PromptFormat::MinimalQwen
1643                    | PromptFormat::SeedCoder1120 => {
1644                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
1645                            let edits = vec![];
1646                            (&active_snapshot, edits)
1647                        } else {
1648                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
1649                        }
1650                    }
1651                    PromptFormat::OldTextNewText => {
1652                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1653                            .await?
1654                    }
1655                    _ => {
1656                        bail!("unsupported prompt format {}", options.prompt_format)
1657                    }
1658                };
1659
1660                let edited_buffer = included_files
1661                    .iter()
1662                    .find_map(|(buffer, snapshot, _, _)| {
1663                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1664                            Some(buffer.clone())
1665                        } else {
1666                            None
1667                        }
1668                    })
1669                    .context("Failed to find buffer in included_buffers")?;
1670
1671                anyhow::Ok((
1672                    Some((
1673                        request_id,
1674                        edited_buffer,
1675                        edited_buffer_snapshot.clone(),
1676                        edits,
1677                    )),
1678                    usage,
1679                ))
1680            }
1681        });
1682
1683        cx.spawn({
1684            async move |this, cx| {
1685                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1686                    Self::handle_api_response(&this, request_task.await, cx)?
1687                else {
1688                    return Ok(None);
1689                };
1690
1691                // TODO telemetry: duration, etc
1692                Ok(
1693                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1694                        .await,
1695                )
1696            }
1697        })
1698    }
1699
1700    async fn send_raw_llm_request(
1701        request: open_ai::Request,
1702        client: Arc<Client>,
1703        llm_token: LlmApiToken,
1704        app_version: Version,
1705        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1706        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1707    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1708        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1709            http_client::Url::parse(&predict_edits_url)?
1710        } else {
1711            client
1712                .http_client()
1713                .build_zed_llm_url("/predict_edits/raw", &[])?
1714        };
1715
1716        #[cfg(feature = "eval-support")]
1717        let cache_key = if let Some(cache) = eval_cache {
1718            use collections::FxHasher;
1719            use std::hash::{Hash, Hasher};
1720
1721            let mut hasher = FxHasher::default();
1722            url.hash(&mut hasher);
1723            let request_str = serde_json::to_string_pretty(&request)?;
1724            request_str.hash(&mut hasher);
1725            let hash = hasher.finish();
1726
1727            let key = (eval_cache_kind, hash);
1728            if let Some(response_str) = cache.read(key) {
1729                return Ok((serde_json::from_str(&response_str)?, None));
1730            }
1731
1732            Some((cache, request_str, key))
1733        } else {
1734            None
1735        };
1736
1737        let (response, usage) = Self::send_api_request(
1738            |builder| {
1739                let req = builder
1740                    .uri(url.as_ref())
1741                    .body(serde_json::to_string(&request)?.into());
1742                Ok(req?)
1743            },
1744            client,
1745            llm_token,
1746            app_version,
1747        )
1748        .await?;
1749
1750        #[cfg(feature = "eval-support")]
1751        if let Some((cache, request, key)) = cache_key {
1752            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1753        }
1754
1755        Ok((response, usage))
1756    }
1757
1758    fn handle_api_response<T>(
1759        this: &WeakEntity<Self>,
1760        response: Result<(T, Option<EditPredictionUsage>)>,
1761        cx: &mut gpui::AsyncApp,
1762    ) -> Result<T> {
1763        match response {
1764            Ok((data, usage)) => {
1765                if let Some(usage) = usage {
1766                    this.update(cx, |this, cx| {
1767                        this.user_store.update(cx, |user_store, cx| {
1768                            user_store.update_edit_prediction_usage(usage, cx);
1769                        });
1770                    })
1771                    .ok();
1772                }
1773                Ok(data)
1774            }
1775            Err(err) => {
1776                if err.is::<ZedUpdateRequiredError>() {
1777                    cx.update(|cx| {
1778                        this.update(cx, |this, _cx| {
1779                            this.update_required = true;
1780                        })
1781                        .ok();
1782
1783                        let error_message: SharedString = err.to_string().into();
1784                        show_app_notification(
1785                            NotificationId::unique::<ZedUpdateRequiredError>(),
1786                            cx,
1787                            move |cx| {
1788                                cx.new(|cx| {
1789                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1790                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1791                                })
1792                            },
1793                        );
1794                    })
1795                    .ok();
1796                }
1797                Err(err)
1798            }
1799        }
1800    }
1801
1802    async fn send_api_request<Res>(
1803        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1804        client: Arc<Client>,
1805        llm_token: LlmApiToken,
1806        app_version: Version,
1807    ) -> Result<(Res, Option<EditPredictionUsage>)>
1808    where
1809        Res: DeserializeOwned,
1810    {
1811        let http_client = client.http_client();
1812        let mut token = llm_token.acquire(&client).await?;
1813        let mut did_retry = false;
1814
1815        loop {
1816            let request_builder = http_client::Request::builder().method(Method::POST);
1817
1818            let request = build(
1819                request_builder
1820                    .header("Content-Type", "application/json")
1821                    .header("Authorization", format!("Bearer {}", token))
1822                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1823            )?;
1824
1825            let mut response = http_client.send(request).await?;
1826
1827            if let Some(minimum_required_version) = response
1828                .headers()
1829                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1830                .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
1831            {
1832                anyhow::ensure!(
1833                    app_version >= minimum_required_version,
1834                    ZedUpdateRequiredError {
1835                        minimum_version: minimum_required_version
1836                    }
1837                );
1838            }
1839
1840            if response.status().is_success() {
1841                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1842
1843                let mut body = Vec::new();
1844                response.body_mut().read_to_end(&mut body).await?;
1845                return Ok((serde_json::from_slice(&body)?, usage));
1846            } else if !did_retry
1847                && response
1848                    .headers()
1849                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1850                    .is_some()
1851            {
1852                did_retry = true;
1853                token = llm_token.refresh(&client).await?;
1854            } else {
1855                let mut body = String::new();
1856                response.body_mut().read_to_string(&mut body).await?;
1857                anyhow::bail!(
1858                    "Request failed with status: {:?}\nBody: {}",
1859                    response.status(),
1860                    body
1861                );
1862            }
1863        }
1864    }
1865
1866    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1867    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1868
1869    // Refresh the related excerpts when the user just beguns editing after
1870    // an idle period, and after they pause editing.
1871    fn refresh_context_if_needed(
1872        &mut self,
1873        project: &Entity<Project>,
1874        buffer: &Entity<language::Buffer>,
1875        cursor_position: language::Anchor,
1876        cx: &mut Context<Self>,
1877    ) {
1878        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1879            return;
1880        }
1881
1882        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1883            return;
1884        };
1885
1886        let now = Instant::now();
1887        let was_idle = zeta_project
1888            .refresh_context_timestamp
1889            .map_or(true, |timestamp| {
1890                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1891            });
1892        zeta_project.refresh_context_timestamp = Some(now);
1893        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1894            let buffer = buffer.clone();
1895            let project = project.clone();
1896            async move |this, cx| {
1897                if was_idle {
1898                    log::debug!("refetching edit prediction context after idle");
1899                } else {
1900                    cx.background_executor()
1901                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1902                        .await;
1903                    log::debug!("refetching edit prediction context after pause");
1904                }
1905                this.update(cx, |this, cx| {
1906                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1907
1908                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1909                        zeta_project.refresh_context_task = Some(task.log_err());
1910                    };
1911                })
1912                .ok()
1913            }
1914        }));
1915    }
1916
    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    // and avoid spawning more than one concurrent task.
    //
    // Pipeline: build a retrieval prompt from the cursor excerpt and recent
    // events, ask the model to plan searches via a forced tool schema, run the
    // resulting searches, then store the found excerpts as the project's
    // context. Debug info is streamed to `debug_tx` at each stage.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op unless this project is tracked and agentic mode is active.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // If no excerpt can be selected around the cursor there is nothing to do.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        // Untitled buffers get a placeholder path for the prompt.
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Declared inside the function so the schema/description pair is
        // computed lazily, once, on first use (LazyLock).
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            // The search tool is the only tool offered to the planner.
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the search tool; calls to
            // unknown tools are logged and skipped rather than failing.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight marker so future refreshes can start.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                // Only overwrite the stored context when the search succeeded.
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
2132
2133    pub fn set_context(
2134        &mut self,
2135        project: Entity<Project>,
2136        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
2137    ) {
2138        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
2139            zeta_project.context = Some(context);
2140        }
2141    }
2142
2143    fn gather_nearby_diagnostics(
2144        cursor_offset: usize,
2145        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
2146        snapshot: &BufferSnapshot,
2147        max_diagnostics_bytes: usize,
2148    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
2149        // TODO: Could make this more efficient
2150        let mut diagnostic_groups = Vec::new();
2151        for (language_server_id, diagnostics) in diagnostic_sets {
2152            let mut groups = Vec::new();
2153            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
2154            diagnostic_groups.extend(
2155                groups
2156                    .into_iter()
2157                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
2158            );
2159        }
2160
2161        // sort by proximity to cursor
2162        diagnostic_groups.sort_by_key(|group| {
2163            let range = &group.entries[group.primary_ix].range;
2164            if range.start >= cursor_offset {
2165                range.start - cursor_offset
2166            } else if cursor_offset >= range.end {
2167                cursor_offset - range.end
2168            } else {
2169                (cursor_offset - range.start).min(range.end - cursor_offset)
2170            }
2171        });
2172
2173        let mut results = Vec::new();
2174        let mut diagnostic_groups_truncated = false;
2175        let mut diagnostics_byte_count = 0;
2176        for group in diagnostic_groups {
2177            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
2178            diagnostics_byte_count += raw_value.get().len();
2179            if diagnostics_byte_count > max_diagnostics_bytes {
2180                diagnostic_groups_truncated = true;
2181                break;
2182            }
2183            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
2184        }
2185
2186        (results, diagnostic_groups_truncated)
2187    }
2188
2189    // TODO: Dedupe with similar code in request_prediction?
2190    pub fn cloud_request_for_zeta_cli(
2191        &mut self,
2192        project: &Entity<Project>,
2193        buffer: &Entity<Buffer>,
2194        position: language::Anchor,
2195        cx: &mut Context<Self>,
2196    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
2197        let project_state = self.projects.get(&project.entity_id());
2198
2199        let index_state = project_state.and_then(|state| {
2200            state
2201                .syntax_index
2202                .as_ref()
2203                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
2204        });
2205        let options = self.options.clone();
2206        let snapshot = buffer.read(cx).snapshot();
2207        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
2208            return Task::ready(Err(anyhow!("No file path for excerpt")));
2209        };
2210        let worktree_snapshots = project
2211            .read(cx)
2212            .worktrees(cx)
2213            .map(|worktree| worktree.read(cx).snapshot())
2214            .collect::<Vec<_>>();
2215
2216        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
2217            let mut path = f.worktree.read(cx).absolutize(&f.path);
2218            if path.pop() { Some(path) } else { None }
2219        });
2220
2221        cx.background_spawn(async move {
2222            let index_state = if let Some(index_state) = index_state {
2223                Some(index_state.lock_owned().await)
2224            } else {
2225                None
2226            };
2227
2228            let cursor_point = position.to_point(&snapshot);
2229
2230            let debug_info = true;
2231            EditPredictionContext::gather_context(
2232                cursor_point,
2233                &snapshot,
2234                parent_abs_path.as_deref(),
2235                match &options.context {
2236                    ContextMode::Agentic(_) => {
2237                        // TODO
2238                        panic!("Llm mode not supported in zeta cli yet");
2239                    }
2240                    ContextMode::Syntax(edit_prediction_context_options) => {
2241                        edit_prediction_context_options
2242                    }
2243                },
2244                index_state.as_deref(),
2245            )
2246            .context("Failed to select excerpt")
2247            .map(|context| {
2248                make_syntax_context_cloud_request(
2249                    excerpt_path.into(),
2250                    context,
2251                    // TODO pass everything
2252                    Vec::new(),
2253                    false,
2254                    Vec::new(),
2255                    false,
2256                    None,
2257                    debug_info,
2258                    &worktree_snapshots,
2259                    index_state.as_deref(),
2260                    Some(options.max_prompt_bytes),
2261                    options.prompt_format,
2262                )
2263            })
2264        })
2265    }
2266
2267    pub fn wait_for_initial_indexing(
2268        &mut self,
2269        project: &Entity<Project>,
2270        cx: &mut Context<Self>,
2271    ) -> Task<Result<()>> {
2272        let zeta_project = self.get_or_init_zeta_project(project, cx);
2273        if let Some(syntax_index) = &zeta_project.syntax_index {
2274            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
2275        } else {
2276            Task::ready(Ok(()))
2277        }
2278    }
2279}
2280
2281pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
2282    let choice = res.choices.pop()?;
2283    let output_text = match choice.message {
2284        open_ai::RequestMessage::Assistant {
2285            content: Some(open_ai::MessageContent::Plain(content)),
2286            ..
2287        } => content,
2288        open_ai::RequestMessage::Assistant {
2289            content: Some(open_ai::MessageContent::Multipart(mut content)),
2290            ..
2291        } => {
2292            if content.is_empty() {
2293                log::error!("No output from Baseten completion response");
2294                return None;
2295            }
2296
2297            match content.remove(0) {
2298                open_ai::MessagePart::Text { text } => text,
2299                open_ai::MessagePart::Image { .. } => {
2300                    log::error!("Expected text, got an image");
2301                    return None;
2302                }
2303            }
2304        }
2305        _ => {
2306            log::error!("Invalid response message: {:?}", choice.message);
2307            return None;
2308        }
2309    };
2310    Some(output_text)
2311}
2312
/// Error indicating the running Zed version is older than the minimum the
/// server requires for edit predictions (see the minimum-required-version
/// header handling in `send_api_request`).
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Minimum version demanded by the server; interpolated into the message.
    minimum_version: Version,
}
2320
2321fn make_syntax_context_cloud_request(
2322    excerpt_path: Arc<Path>,
2323    context: EditPredictionContext,
2324    events: Vec<predict_edits_v3::Event>,
2325    can_collect_data: bool,
2326    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
2327    diagnostic_groups_truncated: bool,
2328    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
2329    debug_info: bool,
2330    worktrees: &Vec<worktree::Snapshot>,
2331    index_state: Option<&SyntaxIndexState>,
2332    prompt_max_bytes: Option<usize>,
2333    prompt_format: PromptFormat,
2334) -> predict_edits_v3::PredictEditsRequest {
2335    let mut signatures = Vec::new();
2336    let mut declaration_to_signature_index = HashMap::default();
2337    let mut referenced_declarations = Vec::new();
2338
2339    for snippet in context.declarations {
2340        let project_entry_id = snippet.declaration.project_entry_id();
2341        let Some(path) = worktrees.iter().find_map(|worktree| {
2342            worktree.entry_for_id(project_entry_id).map(|entry| {
2343                let mut full_path = RelPathBuf::new();
2344                full_path.push(worktree.root_name());
2345                full_path.push(&entry.path);
2346                full_path
2347            })
2348        }) else {
2349            continue;
2350        };
2351
2352        let parent_index = index_state.and_then(|index_state| {
2353            snippet.declaration.parent().and_then(|parent| {
2354                add_signature(
2355                    parent,
2356                    &mut declaration_to_signature_index,
2357                    &mut signatures,
2358                    index_state,
2359                )
2360            })
2361        });
2362
2363        let (text, text_is_truncated) = snippet.declaration.item_text();
2364        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2365            path: path.as_std_path().into(),
2366            text: text.into(),
2367            range: snippet.declaration.item_line_range(),
2368            text_is_truncated,
2369            signature_range: snippet.declaration.signature_range_in_item_text(),
2370            parent_index,
2371            signature_score: snippet.score(DeclarationStyle::Signature),
2372            declaration_score: snippet.score(DeclarationStyle::Declaration),
2373            score_components: snippet.components,
2374        });
2375    }
2376
2377    let excerpt_parent = index_state.and_then(|index_state| {
2378        context
2379            .excerpt
2380            .parent_declarations
2381            .last()
2382            .and_then(|(parent, _)| {
2383                add_signature(
2384                    *parent,
2385                    &mut declaration_to_signature_index,
2386                    &mut signatures,
2387                    index_state,
2388                )
2389            })
2390    });
2391
2392    predict_edits_v3::PredictEditsRequest {
2393        excerpt_path,
2394        excerpt: context.excerpt_text.body,
2395        excerpt_line_range: context.excerpt.line_range,
2396        excerpt_range: context.excerpt.range,
2397        cursor_point: predict_edits_v3::Point {
2398            line: predict_edits_v3::Line(context.cursor_point.row),
2399            column: context.cursor_point.column,
2400        },
2401        referenced_declarations,
2402        included_files: vec![],
2403        signatures,
2404        excerpt_parent,
2405        events,
2406        can_collect_data,
2407        diagnostic_groups,
2408        diagnostic_groups_truncated,
2409        git_info,
2410        debug_info,
2411        prompt_max_bytes,
2412        prompt_format,
2413    }
2414}
2415
2416fn add_signature(
2417    declaration_id: DeclarationId,
2418    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2419    signatures: &mut Vec<Signature>,
2420    index: &SyntaxIndexState,
2421) -> Option<usize> {
2422    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2423        return Some(*signature_index);
2424    }
2425    let Some(parent_declaration) = index.declaration(declaration_id) else {
2426        log::error!("bug: missing parent declaration");
2427        return None;
2428    };
2429    let parent_index = parent_declaration.parent().and_then(|parent| {
2430        add_signature(parent, declaration_to_signature_index, signatures, index)
2431    });
2432    let (text, text_is_truncated) = parent_declaration.signature_text();
2433    let signature_index = signatures.len();
2434    signatures.push(Signature {
2435        text: text.into(),
2436        text_is_truncated,
2437        parent_index,
2438        range: parent_declaration.signature_line_range(),
2439    });
2440    declaration_to_signature_index.insert(declaration_id, signature_index);
2441    Some(signature_index)
2442}
2443
/// Cache key for [`EvalCache`]: the pipeline stage being cached, paired with
/// a 64-bit discriminator (presumably a hash of that stage's input — confirm
/// at call sites).
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2446
/// The eval pipeline stage a cached entry belongs to; first component of an
/// `EvalCacheKey`.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so `EvalCacheKey`
/// tuples can be used directly as keys in hash-based maps/sets by cache
/// implementations.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2454
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Renders the entry kind as its lowercase identifier.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        f.write_str(label)
    }
}
2465
/// Pluggable storage for memoizing expensive eval pipeline results, keyed by
/// [`EvalCacheKey`].
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    /// Returns the cached value for `key`, or `None` on a cache miss.
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    /// Stores `value` under `key`. `input` is the raw input that produced the
    /// value (NOTE(review): presumably retained for inspection/debugging of
    /// cache entries — confirm with implementors).
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2471
// Integration-style tests that drive Zeta against a fake HTTP client; model
// responses are hand-crafted unified diffs or search tool calls.
#[cfg(test)]
mod tests {
    use std::{path::Path, sync::Arc};

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::OffsetRangeExt as _;
    use open_ai::Usage;
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Walks `current_prediction_for_buffer` through three states: a local
    /// prediction for the edited buffer, a context refresh answered by a
    /// search tool call, and a prediction targeting another file (surfaced as
    /// `Jump` until that file's buffer is open, then `Local` for that buffer).
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Model answers with a diff against the file under the cursor.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        cx.run_until_parked();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Model answers the context request with a search tool call instead
        // of plain content.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // This time the model edits 2.txt while the cursor is in 1.txt.
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        cx.run_until_parked();

        // From buffer1's perspective the prediction is a Jump to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        // Once the target buffer is open, the same prediction reads as Local.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// A single `request_prediction` round trip: the model's unified diff is
    /// translated into a buffer edit anchored at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff line "How" -> "How are you?" becomes an insertion of
        // " are you?" at row 1, column 3.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    /// Verifies that an edit made to a registered buffer shows up in the
    /// outgoing prompt as a unified diff of recent events.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit the registered buffer so the change is captured as an event.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The prompt should contain the edit above rendered as a diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }

    // Skipped until we start including diagnostics in prompt
    // #[gpui::test]
    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
    //     let (zeta, mut req_rx) = init_test(cx);
    //     let fs = FakeFs::new(cx.executor());
    //     fs.insert_tree(
    //         "/root",
    //         json!({
    //             "foo.md": "Hello!\nBye"
    //         }),
    //     )
    //     .await;
    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
    //     let diagnostic = lsp::Diagnostic {
    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
    //         ..Default::default()
    //     };

    //     project.update(cx, |project, cx| {
    //         project.lsp_store().update(cx, |lsp_store, cx| {
    //             // Create some diagnostics
    //             lsp_store
    //                 .update_diagnostics(
    //                     LanguageServerId(0),
    //                     lsp::PublishDiagnosticsParams {
    //                         uri: path_to_buffer_uri.clone(),
    //                         diagnostics: vec![diagnostic],
    //                         version: None,
    //                     },
    //                     None,
    //                     language::DiagnosticSourceKind::Pushed,
    //                     &[],
    //                     cx,
    //                 )
    //                 .unwrap();
    //         });
    //     });

    //     let buffer = project
    //         .update(cx, |project, cx| {
    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
    //             project.open_buffer(path, cx)
    //         })
    //         .await
    //         .unwrap();

    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
    //     let position = snapshot.anchor_before(language::Point::new(0, 0));

    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
    //         zeta.request_prediction(&project, &buffer, position, cx)
    //     });

    //     let (request, _respond_tx) = req_rx.next().await.unwrap();

    //     assert_eq!(request.diagnostic_groups.len(), 1);
    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
    //         .unwrap();
    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
    //     assert_eq!(
    //         value,
    //         json!({
    //             "entries": [{
    //                 "range": {
    //                     "start": 8,
    //                     "end": 10
    //                 },
    //                 "diagnostic": {
    //                     "source": null,
    //                     "code": null,
    //                     "code_description": null,
    //                     "severity": 1,
    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
    //                     "markdown": null,
    //                     "group_id": 0,
    //                     "is_primary": true,
    //                     "is_disk_based": false,
    //                     "is_unnecessary": false,
    //                     "source_kind": "Pushed",
    //                     "data": null,
    //                     "underline": true
    //                 }
    //             }],
    //             "primary_ix": 0
    //         })
    //     );
    // }

    /// Builds a minimal OpenAI-style response whose single choice carries
    /// `text` as plain assistant content.
    fn model_response(text: &str) -> open_ai::Response {
        open_ai::Response {
            id: Uuid::new_v4().to_string(),
            object: "response".into(),
            created: 0,
            model: "model".into(),
            choices: vec![open_ai::Choice {
                index: 0,
                message: open_ai::RequestMessage::Assistant {
                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
                    tool_calls: vec![],
                },
                finish_reason: None,
            }],
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }

    /// Extracts the single plain-text user message from `request`, panicking
    /// if the request has any other shape.
    fn prompt_from_request(request: &open_ai::Request) -> &str {
        assert_eq!(request.messages.len(), 1);
        let open_ai::RequestMessage::User {
            content: open_ai::MessageContent::Plain(content),
            ..
        } = &request.messages[0]
        else {
            panic!(
                "Request does not have single user message of type Plain. {:#?}",
                request
            );
        };
        content
    }

    /// Builds a `Zeta` entity backed by a fake HTTP client. Returns the
    /// entity plus a channel yielding each `/predict_edits/raw` request
    /// together with a oneshot sender used to supply the model response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: answer with a static token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Forward the deserialized prediction request to
                            // the test and wait for its reply via oneshot.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}