zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::{Project, ProjectPath};
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::ops::Range;
  35use std::path::Path;
  36use std::str::FromStr as _;
  37use std::sync::{Arc, LazyLock};
  38use std::time::{Duration, Instant};
  39use std::{env, mem};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod assemble_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49mod sweep_ai;
  50pub mod udiff;
  51mod xml_edits;
  52
  53use crate::assemble_excerpts::assemble_excerpts;
  54pub use crate::prediction::EditPrediction;
  55pub use crate::prediction::EditPredictionId;
  56pub use provider::ZetaEditPredictionProvider;
  57
/// Maximum number of events to track when using the Sweep model.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
/// Maximum number of events to track when using the Zed cloud model.
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Consecutive edits ending within this many lines of each other are coalesced
/// into a single change event (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  62
  63pub struct SweepFeatureFlag;
  64
  65impl FeatureFlag for SweepFeatureFlag {
  66    const NAME: &str = "sweep-ai";
  67}
/// Default sizing of the excerpt taken around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim for half of the excerpt before the cursor, half after.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Agentic context retrieval is the default mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for [`ContextMode::Agentic`].
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for [`ContextMode::Syntax`].
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        // Declaration retrieval is disabled by default.
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used by a freshly created [`Zeta`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
  99
/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, route requests to a
/// local Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for the context-retrieval (search) step.
/// Overridable via `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit predictions. Overridable via `ZED_ZETA2_MODEL`;
/// `zeta2-exp` is an alias for the fine-tuned model.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Endpoint override for prediction requests: `ZED_PREDICT_EDITS_URL` wins,
/// then the local Ollama endpoint when enabled; `None` means use Zed's cloud.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 127
/// Feature flag gating the zeta2 edit prediction provider.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Staff do not get this flag automatically; it must be enabled explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
 137
/// App-global handle to the single shared [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 142
/// App-wide entry point for edit predictions: tracks per-project state and
/// recent buffer changes, and talks to the selected prediction backend.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Auth token for the LLM backend; refreshed via `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by project entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports the client is too
    // old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — usage not in this excerpt.
    update_required: bool,
    /// When present, debug events are streamed to this channel.
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    /// API token for the Sweep backend, read from `SWEEP_AI_TOKEN`.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}
 158
/// Which backend serves edit predictions.
#[derive(PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    /// Zed's hosted prediction service.
    ZedCloud,
    /// Sweep AI's autocomplete service (requires `SWEEP_AI_TOKEN`).
    Sweep,
}
 164
/// Tunable parameters for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context is gathered: agentic search or syntax-index retrieval.
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // NOTE(review): consumed outside this excerpt — presumably a debounce for
    // grouping buffer changes; confirm against the caller.
    pub buffer_change_grouping_interval: Duration,
}
 174
/// How context for a prediction is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// The model issues search queries to retrieve context; no syntax index
    /// is built for the project.
    Agentic(AgenticContextOptions),
    /// Context comes from a local [`SyntaxIndex`] of declarations.
    Syntax(EditPredictionContextOptions),
}

/// Options for [`ContextMode::Agentic`].
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 185
 186impl ContextMode {
 187    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 188        match self {
 189            ContextMode::Agentic(options) => &options.excerpt,
 190            ContextMode::Syntax(options) => &options.excerpt,
 191        }
 192    }
 193}
 194
/// Events streamed to debug subscribers (see [`Zeta::debug_info`]).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Payload for [`ZetaDebugInfo::ContextRetrievalStarted`].
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    /// The prompt sent to the context-retrieval model.
    pub search_prompt: String,
}

/// Payload for retrieval progress/completion events.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Payload for [`ZetaDebugInfo::EditPredictionRequested`].
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent on context retrieval before the request was sent.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    /// The locally rendered prompt, or an error message if rendering failed.
    pub local_prompt: Result<String, String>,
    /// Resolves with the model response (or error) and the request duration.
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

/// Payload for [`ZetaDebugInfo::SearchQueriesGenerated`].
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 235
/// Per-project prediction state.
struct ZetaProject {
    /// Only built when the context mode is [`ContextMode::Syntax`].
    syntax_index: Option<Entity<SyntaxIndex>>,
    /// Recorded change events, oldest first, bounded by the per-model cap.
    events: VecDeque<Event>,
    /// Recently active project paths, most recent first (MRU).
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    /// Retrieved context ranges per buffer; `None` until retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    _subscription: gpui::Subscription,
}
 248
/// The prediction currently offered to the user, plus the buffer that
/// requested it (which may differ from the buffer the prediction edits).
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 254
 255impl CurrentEditPrediction {
 256    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 257        let Some(new_edits) = self
 258            .prediction
 259            .interpolate(&self.prediction.buffer.read(cx))
 260        else {
 261            return false;
 262        };
 263
 264        if self.prediction.buffer != old_prediction.prediction.buffer {
 265            return true;
 266        }
 267
 268        let Some(old_edits) = old_prediction
 269            .prediction
 270            .interpolate(&old_prediction.prediction.buffer.read(cx))
 271        else {
 272            return true;
 273        };
 274
 275        // This reduces the occurrence of UI thrash from replacing edits
 276        //
 277        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 278        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 279            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 280            && old_edits.len() == 1
 281            && new_edits.len() == 1
 282        {
 283            let (old_range, old_text) = &old_edits[0];
 284            let (new_range, new_text) = &new_edits[0];
 285            new_range == old_range && new_text.starts_with(old_text.as_ref())
 286        } else {
 287            true
 288        }
 289    }
 290}
 291
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// This buffer requested the prediction, but it edits another buffer.
    Jump { prediction: &'a EditPrediction },
}

/// Bookkeeping for a buffer whose edits are being tracked.
struct RegisteredBuffer {
    /// The last snapshot diffed against; replaced on each reported change.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}
 303
/// A recorded user action, included as history in prediction requests.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents (and possibly path) changed between two snapshots.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// End of the last edit; used to coalesce nearby consecutive edits.
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}
 313
 314impl Event {
 315    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 316        match self {
 317            Event::BufferChange {
 318                old_snapshot,
 319                new_snapshot,
 320                ..
 321            } => {
 322                let path = new_snapshot.file().map(|f| f.full_path(cx));
 323
 324                let old_path = old_snapshot.file().and_then(|f| {
 325                    let old_path = f.full_path(cx);
 326                    if Some(&old_path) != path.as_ref() {
 327                        Some(old_path)
 328                    } else {
 329                        None
 330                    }
 331                });
 332
 333                // TODO [zeta2] move to bg?
 334                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 335
 336                if path == old_path && diff.is_empty() {
 337                    None
 338                } else {
 339                    Some(predict_edits_v3::Event::BufferChange {
 340                        old_path,
 341                        path,
 342                        diff,
 343                        //todo: Actually detect if this edit was predicted or not
 344                        predicted: false,
 345                    })
 346                }
 347            }
 348        }
 349    }
 350
 351    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
 352        match self {
 353            Event::BufferChange { new_snapshot, .. } => new_snapshot
 354                .file()
 355                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
 356        }
 357    }
 358}
 359
 360impl Zeta {
 361    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 362        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 363    }
 364
 365    pub fn global(
 366        client: &Arc<Client>,
 367        user_store: &Entity<UserStore>,
 368        cx: &mut App,
 369    ) -> Entity<Self> {
 370        cx.try_global::<ZetaGlobal>()
 371            .map(|global| global.0.clone())
 372            .unwrap_or_else(|| {
 373                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 374                cx.set_global(ZetaGlobal(zeta.clone()));
 375                zeta
 376            })
 377    }
 378
    /// Creates a new [`Zeta`] with default options and the Zed cloud backend.
    /// Use [`Zeta::global`] to obtain the shared instance.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Re-fetch the LLM token whenever the listener signals a refresh.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            sweep_api_token: None,
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }
 409
 410    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
 411        if model == ZetaEditPredictionModel::Sweep {
 412            self.sweep_api_token = std::env::var("SWEEP_AI_TOKEN")
 413                .context("No SWEEP_AI_TOKEN environment variable set")
 414                .log_err();
 415        }
 416        self.edit_prediction_model = model;
 417    }
 418
    /// Installs a cache used to record/replay model responses during evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }
 423
 424    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 425        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 426        self.debug_tx = Some(debug_watch_tx);
 427        debug_watch_rx
 428    }
 429
    /// The current prediction options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 433
    /// Replaces the prediction options.
    ///
    /// NOTE(review): an existing project's `syntax_index` is not created or
    /// dropped when the context mode changes — confirm whether that's intended.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 437
 438    pub fn clear_history(&mut self) {
 439        for zeta_project in self.projects.values_mut() {
 440            zeta_project.events.clear();
 441        }
 442    }
 443
 444    pub fn history_for_project(
 445        &self,
 446        project: &Entity<Project>,
 447    ) -> impl DoubleEndedIterator<Item = &Event> {
 448        self.projects
 449            .get(&project.entity_id())
 450            .map(|project| project.events.iter())
 451            .into_iter()
 452            .flatten()
 453    }
 454
 455    pub fn context_for_project(
 456        &self,
 457        project: &Entity<Project>,
 458    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 459        self.projects
 460            .get(&project.entity_id())
 461            .and_then(|project| {
 462                Some(
 463                    project
 464                        .context
 465                        .as_ref()?
 466                        .iter()
 467                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 468                )
 469            })
 470            .into_iter()
 471            .flatten()
 472    }
 473
    /// Current edit-prediction usage for the signed-in user, if known.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }
 477
    /// Ensures per-project state exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }
 481
    /// Starts tracking edits to `buffer` within `project`, initializing the
    /// project state if needed.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 491
    /// Returns the state for `project`, creating it on first use.
    ///
    /// New projects get a syntax index only in [`ContextMode::Syntax`] mode,
    /// plus a subscription that maintains `recent_paths` as an MRU list of
    /// active project entries.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, |this, project, event, cx| {
                    // TODO [zeta2] init with recent paths
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
                            let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                            if let Some(path) = path {
                                // Move the newly active path to the front of the
                                // MRU list, deduplicating any earlier occurrence.
                                if let Some(ix) = zeta_project
                                    .recent_paths
                                    .iter()
                                    .position(|probe| probe == &path)
                                {
                                    zeta_project.recent_paths.remove(ix);
                                }
                                zeta_project.recent_paths.push_front(path);
                            }
                        }
                    }
                }),
            })
    }
 535
    /// Returns the [`RegisteredBuffer`] for `buffer`, registering it on first
    /// use.
    ///
    /// Registration subscribes to buffer edits (to record change events) and
    /// to the buffer entity's release (to drop its entry from project state).
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Hold a weak handle so the subscription doesn't
                            // keep the project alive.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Clean up when the buffer entity is dropped.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 573
 574    fn report_changes_for_buffer(
 575        &mut self,
 576        buffer: &Entity<Buffer>,
 577        project: &Entity<Project>,
 578        cx: &mut Context<Self>,
 579    ) {
 580        let event_count_max = match self.edit_prediction_model {
 581            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
 582            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
 583        };
 584
 585        let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
 586        let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);
 587
 588        let new_snapshot = buffer.read(cx).snapshot();
 589        if new_snapshot.version == registered_buffer.snapshot.version {
 590            return;
 591        }
 592
 593        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 594        let end_edit_anchor = new_snapshot
 595            .anchored_edits_since::<Point>(&old_snapshot.version)
 596            .last()
 597            .map(|(_, range)| range.end);
 598        let events = &mut sweep_ai_project.events;
 599
 600        if let Some(Event::BufferChange {
 601            new_snapshot: last_new_snapshot,
 602            end_edit_anchor: last_end_edit_anchor,
 603            ..
 604        }) = events.back_mut()
 605        {
 606            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
 607                == last_new_snapshot.remote_id()
 608                && old_snapshot.version == last_new_snapshot.version;
 609
 610            let should_coalesce = is_next_snapshot_of_same_buffer
 611                && end_edit_anchor
 612                    .as_ref()
 613                    .zip(last_end_edit_anchor.as_ref())
 614                    .is_some_and(|(a, b)| {
 615                        let a = a.to_point(&new_snapshot);
 616                        let b = b.to_point(&new_snapshot);
 617                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
 618                    });
 619
 620            if should_coalesce {
 621                *last_end_edit_anchor = end_edit_anchor;
 622                *last_new_snapshot = new_snapshot;
 623                return;
 624            }
 625        }
 626
 627        if events.len() >= event_count_max {
 628            events.pop_front();
 629        }
 630
 631        events.push_back(Event::BufferChange {
 632            old_snapshot,
 633            new_snapshot,
 634            end_edit_anchor,
 635            timestamp: Instant::now(),
 636        });
 637    }
 638
 639    fn current_prediction_for_buffer(
 640        &self,
 641        buffer: &Entity<Buffer>,
 642        project: &Entity<Project>,
 643        cx: &App,
 644    ) -> Option<BufferEditPrediction<'_>> {
 645        let project_state = self.projects.get(&project.entity_id())?;
 646
 647        let CurrentEditPrediction {
 648            requested_by_buffer_id,
 649            prediction,
 650        } = project_state.current_prediction.as_ref()?;
 651
 652        if prediction.targets_buffer(buffer.read(cx)) {
 653            Some(BufferEditPrediction::Local { prediction })
 654        } else if *requested_by_buffer_id == buffer.entity_id() {
 655            Some(BufferEditPrediction::Jump { prediction })
 656        } else {
 657            None
 658        }
 659    }
 660
    /// Marks the current prediction as accepted: clears it locally and reports
    /// the acceptance to the server (fire-and-forget; errors are logged).
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (for testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // Surfaces update-required/auth errors via the shared handler.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 705
 706    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 707        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 708            project_state.current_prediction.take();
 709        };
 710    }
 711
 712    pub fn refresh_prediction(
 713        &mut self,
 714        project: &Entity<Project>,
 715        buffer: &Entity<Buffer>,
 716        position: language::Anchor,
 717        cx: &mut Context<Self>,
 718    ) -> Task<Result<()>> {
 719        let request_task = self.request_prediction(project, buffer, position, cx);
 720        let buffer = buffer.clone();
 721        let project = project.clone();
 722
 723        cx.spawn(async move |this, cx| {
 724            if let Some(prediction) = request_task.await? {
 725                this.update(cx, |this, cx| {
 726                    let project_state = this
 727                        .projects
 728                        .get_mut(&project.entity_id())
 729                        .context("Project not found")?;
 730
 731                    let new_prediction = CurrentEditPrediction {
 732                        requested_by_buffer_id: buffer.entity_id(),
 733                        prediction: prediction,
 734                    };
 735
 736                    if project_state
 737                        .current_prediction
 738                        .as_ref()
 739                        .is_none_or(|old_prediction| {
 740                            new_prediction.should_replace_prediction(&old_prediction, cx)
 741                        })
 742                    {
 743                        project_state.current_prediction = Some(new_prediction);
 744                    }
 745                    anyhow::Ok(())
 746                })??;
 747            }
 748            Ok(())
 749        })
 750    }
 751
    /// Requests an edit prediction for `position` in `active_buffer`,
    /// dispatching to the currently selected backend.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => {
                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
            }
            ZetaEditPredictionModel::Sweep => {
                self.request_prediction_with_sweep(project, active_buffer, position, cx)
            }
        }
    }
 768
 769    fn request_prediction_with_sweep(
 770        &mut self,
 771        project: &Entity<Project>,
 772        active_buffer: &Entity<Buffer>,
 773        position: language::Anchor,
 774        cx: &mut Context<Self>,
 775    ) -> Task<Result<Option<EditPrediction>>> {
 776        let snapshot = active_buffer.read(cx).snapshot();
 777        let debug_info = self.sweep_ai_debug_info.clone();
 778        let Some(api_token) = self.sweep_api_token.clone() else {
 779            return Task::ready(Ok(None));
 780        };
 781        let full_path: Arc<Path> = snapshot
 782            .file()
 783            .map(|file| file.full_path(cx))
 784            .unwrap_or_else(|| "untitled".into())
 785            .into();
 786
 787        let project_file = project::File::from_dyn(snapshot.file());
 788        let repo_name = project_file
 789            .map(|file| file.worktree.read(cx).root_name_str())
 790            .unwrap_or("untitled")
 791            .into();
 792        let offset = position.to_offset(&snapshot);
 793
 794        let project_state = self.get_or_init_zeta_project(project, cx);
 795        let events = project_state.events.clone();
 796        let recent_buffers = project_state.recent_paths.iter().cloned();
 797        let http_client = cx.http_client();
 798
 799        let recent_buffer_snapshots = recent_buffers
 800            .filter_map(|project_path| {
 801                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
 802                if active_buffer == &buffer {
 803                    None
 804                } else {
 805                    Some(buffer.read(cx).snapshot())
 806                }
 807            })
 808            .take(3)
 809            .collect::<Vec<_>>();
 810
 811        let result = cx.background_spawn(async move {
 812            let text = snapshot.text();
 813
 814            let mut recent_changes = String::new();
 815            for event in events {
 816                sweep_ai::write_event(event, &mut recent_changes).unwrap();
 817            }
 818
 819            let file_chunks = recent_buffer_snapshots
 820                .into_iter()
 821                .map(|snapshot| {
 822                    let end_point = language::Point::new(30, 0).min(snapshot.max_point());
 823                    sweep_ai::FileChunk {
 824                        content: snapshot
 825                            .text_for_range(language::Point::zero()..end_point)
 826                            .collect(),
 827                        file_path: snapshot
 828                            .file()
 829                            .map(|f| f.path().as_unix_str())
 830                            .unwrap_or("untitled")
 831                            .to_string(),
 832                        start_line: 0,
 833                        end_line: end_point.row as usize,
 834                        timestamp: snapshot.file().and_then(|file| {
 835                            Some(
 836                                file.disk_state()
 837                                    .mtime()?
 838                                    .to_seconds_and_nanos_for_persistence()?
 839                                    .0,
 840                            )
 841                        }),
 842                    }
 843                })
 844                .collect();
 845
 846            let request_body = sweep_ai::AutocompleteRequest {
 847                debug_info,
 848                repo_name,
 849                file_path: full_path.clone(),
 850                file_contents: text.clone(),
 851                original_file_contents: text,
 852                cursor_position: offset,
 853                recent_changes: recent_changes.clone(),
 854                changes_above_cursor: true,
 855                multiple_suggestions: false,
 856                branch: None,
 857                file_chunks,
 858                retrieval_chunks: vec![],
 859                recent_user_actions: vec![],
 860                // TODO
 861                privacy_mode_enabled: false,
 862            };
 863
 864            let mut buf: Vec<u8> = Vec::new();
 865            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
 866            serde_json::to_writer(writer, &request_body)?;
 867            let body: AsyncBody = buf.into();
 868
 869            const SWEEP_API_URL: &str =
 870                "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 871
 872            let request = http_client::Request::builder()
 873                .uri(SWEEP_API_URL)
 874                .header("Content-Type", "application/json")
 875                .header("Authorization", format!("Bearer {}", api_token))
 876                .header("Connection", "keep-alive")
 877                .header("Content-Encoding", "br")
 878                .method(Method::POST)
 879                .body(body)?;
 880
 881            let mut response = http_client.send(request).await?;
 882
 883            let mut body: Vec<u8> = Vec::new();
 884            response.body_mut().read_to_end(&mut body).await?;
 885
 886            if !response.status().is_success() {
 887                anyhow::bail!(
 888                    "Request failed with status: {:?}\nBody: {}",
 889                    response.status(),
 890                    String::from_utf8_lossy(&body),
 891                );
 892            };
 893
 894            let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
 895
 896            let old_text = snapshot
 897                .text_for_range(response.start_index..response.end_index)
 898                .collect::<String>();
 899            let edits = language::text_diff(&old_text, &response.completion)
 900                .into_iter()
 901                .map(|(range, text)| {
 902                    (
 903                        snapshot.anchor_after(response.start_index + range.start)
 904                            ..snapshot.anchor_before(response.start_index + range.end),
 905                        text,
 906                    )
 907                })
 908                .collect::<Vec<_>>();
 909
 910            anyhow::Ok((response.autocomplete_id, edits, snapshot))
 911        });
 912
 913        let buffer = active_buffer.clone();
 914
 915        cx.spawn(async move |_, cx| {
 916            let (id, edits, old_snapshot) = result.await?;
 917
 918            if edits.is_empty() {
 919                return anyhow::Ok(None);
 920            }
 921
 922            let Some((edits, new_snapshot, preview_task)) =
 923                buffer.read_with(cx, |buffer, cx| {
 924                    let new_snapshot = buffer.snapshot();
 925
 926                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
 927                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
 928                            .into();
 929                    let preview_task = buffer.preview_edits(edits.clone(), cx);
 930
 931                    Some((edits, new_snapshot, preview_task))
 932                })?
 933            else {
 934                return anyhow::Ok(None);
 935            };
 936
 937            let prediction = EditPrediction {
 938                id: EditPredictionId(id.into()),
 939                edits,
 940                snapshot: new_snapshot,
 941                edit_preview: preview_task.await,
 942                buffer,
 943            };
 944
 945            anyhow::Ok(Some(prediction))
 946        })
 947    }
 948
    /// Requests an edit prediction from the Zed cloud LLM endpoint.
    ///
    /// Gathers context on the foreground (recent events, nearby diagnostics,
    /// context files, worktree snapshots), then on a background task builds a
    /// `PredictEditsRequest` — either from agentic context (`ContextMode::Agentic`)
    /// or syntax-index context (`ContextMode::Syntax`) — renders it into a prompt,
    /// sends it via [`Self::send_raw_llm_request`], and parses the model output
    /// into concrete buffer edits according to `options.prompt_format`.
    ///
    /// Returns `Ok(None)` when no prediction can be made (no excerpt/context could
    /// be selected, or the response contained no text).
    fn request_prediction_with_zed_cloud(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax-index state handle (if indexing is enabled) so the
        // background task can lock it without going through the entity system.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Recent user-edit events, converted to request form while we still
        // have `cx` access.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the active file's parent directory, used for
        // syntax-context gathering below.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        // For evals: wait until all context buffers have finished parsing so
        // results are deterministic.
        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // (buffer entity, snapshot, full path, anchor ranges) for every
        // context file that still has a backing file.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Deterministic ordering: by path, then by number of ranges.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                // Hold the syntax-index lock for the duration of context
                // gathering, if an index exists.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            // No excerpt could be selected around the cursor.
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // entry (moving it to the end of the list), or append
                        // a new entry if the active buffer wasn't included.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges to line ranges and assemble
                        // the excerpt payloads for each included file.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        // In agentic mode the top-level excerpt fields are
                        // left empty; context travels in `included_files`.
                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            // No syntax context available at the cursor.
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // If a debug listener is attached, announce the request and
                // keep a channel to forward the eventual response to it.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Debug-build escape hatch: skip the network request entirely.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                // Non-streaming, tool-free chat completion carrying the whole
                // prompt as a single user message.
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                // Forward the (possibly failed) response to the debug listener.
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    // Model returned no usable text; report usage only.
                    return Ok((None, usage));
                };

                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path mentioned in the model output back to the
                // snapshot/ranges of the corresponding included file.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the model output into edits using the parser that
                // matches the prompt format we asked for.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal | PromptFormat::MinimalQwen => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        // Back on the foreground: record usage / surface errors, then build
        // the final `EditPrediction`.
        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
1331
1332    async fn send_raw_llm_request(
1333        request: open_ai::Request,
1334        client: Arc<Client>,
1335        llm_token: LlmApiToken,
1336        app_version: SemanticVersion,
1337        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1338        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1339    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1340        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1341            http_client::Url::parse(&predict_edits_url)?
1342        } else {
1343            client
1344                .http_client()
1345                .build_zed_llm_url("/predict_edits/raw", &[])?
1346        };
1347
1348        #[cfg(feature = "eval-support")]
1349        let cache_key = if let Some(cache) = eval_cache {
1350            use collections::FxHasher;
1351            use std::hash::{Hash, Hasher};
1352
1353            let mut hasher = FxHasher::default();
1354            url.hash(&mut hasher);
1355            let request_str = serde_json::to_string_pretty(&request)?;
1356            request_str.hash(&mut hasher);
1357            let hash = hasher.finish();
1358
1359            let key = (eval_cache_kind, hash);
1360            if let Some(response_str) = cache.read(key) {
1361                return Ok((serde_json::from_str(&response_str)?, None));
1362            }
1363
1364            Some((cache, request_str, key))
1365        } else {
1366            None
1367        };
1368
1369        let (response, usage) = Self::send_api_request(
1370            |builder| {
1371                let req = builder
1372                    .uri(url.as_ref())
1373                    .body(serde_json::to_string(&request)?.into());
1374                Ok(req?)
1375            },
1376            client,
1377            llm_token,
1378            app_version,
1379        )
1380        .await?;
1381
1382        #[cfg(feature = "eval-support")]
1383        if let Some((cache, request, key)) = cache_key {
1384            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1385        }
1386
1387        Ok((response, usage))
1388    }
1389
1390    fn handle_api_response<T>(
1391        this: &WeakEntity<Self>,
1392        response: Result<(T, Option<EditPredictionUsage>)>,
1393        cx: &mut gpui::AsyncApp,
1394    ) -> Result<T> {
1395        match response {
1396            Ok((data, usage)) => {
1397                if let Some(usage) = usage {
1398                    this.update(cx, |this, cx| {
1399                        this.user_store.update(cx, |user_store, cx| {
1400                            user_store.update_edit_prediction_usage(usage, cx);
1401                        });
1402                    })
1403                    .ok();
1404                }
1405                Ok(data)
1406            }
1407            Err(err) => {
1408                if err.is::<ZedUpdateRequiredError>() {
1409                    cx.update(|cx| {
1410                        this.update(cx, |this, _cx| {
1411                            this.update_required = true;
1412                        })
1413                        .ok();
1414
1415                        let error_message: SharedString = err.to_string().into();
1416                        show_app_notification(
1417                            NotificationId::unique::<ZedUpdateRequiredError>(),
1418                            cx,
1419                            move |cx| {
1420                                cx.new(|cx| {
1421                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1422                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1423                                })
1424                            },
1425                        );
1426                    })
1427                    .ok();
1428                }
1429                Err(err)
1430            }
1431        }
1432    }
1433
    /// Sends an authenticated POST request built by `build`, deserializing the
    /// JSON response body into `Res`.
    ///
    /// Behavior:
    /// - Attaches the LLM token as a bearer `Authorization` header and the app
    ///   version header.
    /// - If the server reports a minimum required version header that exceeds
    ///   `app_version`, fails with [`ZedUpdateRequiredError`].
    /// - If the server signals an expired LLM token, refreshes the token and
    ///   retries exactly once; any other non-success status fails with the
    ///   response body included in the error.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        // Loops at most twice: once with the current token, and once more
        // after a refresh if the server reported the token expired.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server-advertised minimum client version, if any.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Token expired: refresh once and retry the request.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1497
    /// How long without a context refresh before the user counts as idle;
    /// editing after this window triggers an immediate context refetch.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refetching context while the user keeps editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1500
    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context refresh only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // Idle iff no refresh was recorded within the idle window (or ever).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // Replacing the debounce task cancels any previously pending refresh.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // Coming out of idle: refresh immediately, no debounce.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Actively editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Store the task so it runs to completion (see refresh_context).
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1548
    /// Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    /// and avoid spawning more than one concurrent task.
    ///
    /// Flow: build a context-retrieval prompt from the excerpt around the cursor and
    /// recent events, ask the LLM to plan search queries (via a tool call), run those
    /// searches, and store the resulting excerpts on the project's state. Debug info
    /// is streamed over `debug_tx` at each stage when a listener is attached.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // No-op for projects that were never registered with Zeta.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Context retrieval only applies in the agentic context mode.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Select the excerpt around the cursor that will seed the search prompt.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        // Clone everything the async task needs before moving into the spawn.
        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        // Build the retrieval-planning prompt from the excerpt plus recent edit events.
        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // JSON schema and description for the search tool, built once lazily and
        // cloned per request. Lives here because it's only used by this method.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Non-streaming completion request offering exactly one tool: the search tool.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect the queries from every call to the search tool; calls to
            // unknown tools are logged and skipped rather than failing the task.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            // Store the results (or propagate the error) and clear the in-flight
            // task handle, even when the project was dropped meanwhile.
            this.update(cx, |this, _cx| {
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1764
1765    pub fn set_context(
1766        &mut self,
1767        project: Entity<Project>,
1768        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1769    ) {
1770        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1771            zeta_project.context = Some(context);
1772        }
1773    }
1774
1775    fn gather_nearby_diagnostics(
1776        cursor_offset: usize,
1777        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1778        snapshot: &BufferSnapshot,
1779        max_diagnostics_bytes: usize,
1780    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1781        // TODO: Could make this more efficient
1782        let mut diagnostic_groups = Vec::new();
1783        for (language_server_id, diagnostics) in diagnostic_sets {
1784            let mut groups = Vec::new();
1785            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1786            diagnostic_groups.extend(
1787                groups
1788                    .into_iter()
1789                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1790            );
1791        }
1792
1793        // sort by proximity to cursor
1794        diagnostic_groups.sort_by_key(|group| {
1795            let range = &group.entries[group.primary_ix].range;
1796            if range.start >= cursor_offset {
1797                range.start - cursor_offset
1798            } else if cursor_offset >= range.end {
1799                cursor_offset - range.end
1800            } else {
1801                (cursor_offset - range.start).min(range.end - cursor_offset)
1802            }
1803        });
1804
1805        let mut results = Vec::new();
1806        let mut diagnostic_groups_truncated = false;
1807        let mut diagnostics_byte_count = 0;
1808        for group in diagnostic_groups {
1809            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1810            diagnostics_byte_count += raw_value.get().len();
1811            if diagnostics_byte_count > max_diagnostics_bytes {
1812                diagnostic_groups_truncated = true;
1813                break;
1814            }
1815            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1816        }
1817
1818        (results, diagnostic_groups_truncated)
1819    }
1820
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a `PredictEditsRequest` for the zeta CLI by gathering syntax-based
    /// edit-prediction context around `position` on a background thread.
    ///
    /// Panics inside the returned task if the configured context mode is
    /// `Agentic` — only `ContextMode::Syntax` is supported here.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Shared syntax-index state, if this project has an index.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        // Snapshot all worktrees up front so the background task needs no entities.
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when it has one.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            // Hold the index lock for the duration of context gathering.
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1898
1899    pub fn wait_for_initial_indexing(
1900        &mut self,
1901        project: &Entity<Project>,
1902        cx: &mut Context<Self>,
1903    ) -> Task<Result<()>> {
1904        let zeta_project = self.get_or_init_zeta_project(project, cx);
1905        if let Some(syntax_index) = &zeta_project.syntax_index {
1906            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
1907        } else {
1908            Task::ready(Ok(()))
1909        }
1910    }
1911}
1912
1913pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1914    let choice = res.choices.pop()?;
1915    let output_text = match choice.message {
1916        open_ai::RequestMessage::Assistant {
1917            content: Some(open_ai::MessageContent::Plain(content)),
1918            ..
1919        } => content,
1920        open_ai::RequestMessage::Assistant {
1921            content: Some(open_ai::MessageContent::Multipart(mut content)),
1922            ..
1923        } => {
1924            if content.is_empty() {
1925                log::error!("No output from Baseten completion response");
1926                return None;
1927            }
1928
1929            match content.remove(0) {
1930                open_ai::MessagePart::Text { text } => text,
1931                open_ai::MessagePart::Image { .. } => {
1932                    log::error!("Expected text, got an image");
1933                    return None;
1934                }
1935            }
1936        }
1937        _ => {
1938            log::error!("Invalid response message: {:?}", choice.message);
1939            return None;
1940        }
1941    };
1942    Some(output_text)
1943}
1944
/// Error raised when the server rejects a request because this Zed build is too
/// old to use edit predictions; carries the minimum acceptable version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    // Lowest Zed version the server will accept.
    minimum_version: SemanticVersion,
}
1952
1953fn make_syntax_context_cloud_request(
1954    excerpt_path: Arc<Path>,
1955    context: EditPredictionContext,
1956    events: Vec<predict_edits_v3::Event>,
1957    can_collect_data: bool,
1958    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1959    diagnostic_groups_truncated: bool,
1960    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1961    debug_info: bool,
1962    worktrees: &Vec<worktree::Snapshot>,
1963    index_state: Option<&SyntaxIndexState>,
1964    prompt_max_bytes: Option<usize>,
1965    prompt_format: PromptFormat,
1966) -> predict_edits_v3::PredictEditsRequest {
1967    let mut signatures = Vec::new();
1968    let mut declaration_to_signature_index = HashMap::default();
1969    let mut referenced_declarations = Vec::new();
1970
1971    for snippet in context.declarations {
1972        let project_entry_id = snippet.declaration.project_entry_id();
1973        let Some(path) = worktrees.iter().find_map(|worktree| {
1974            worktree.entry_for_id(project_entry_id).map(|entry| {
1975                let mut full_path = RelPathBuf::new();
1976                full_path.push(worktree.root_name());
1977                full_path.push(&entry.path);
1978                full_path
1979            })
1980        }) else {
1981            continue;
1982        };
1983
1984        let parent_index = index_state.and_then(|index_state| {
1985            snippet.declaration.parent().and_then(|parent| {
1986                add_signature(
1987                    parent,
1988                    &mut declaration_to_signature_index,
1989                    &mut signatures,
1990                    index_state,
1991                )
1992            })
1993        });
1994
1995        let (text, text_is_truncated) = snippet.declaration.item_text();
1996        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1997            path: path.as_std_path().into(),
1998            text: text.into(),
1999            range: snippet.declaration.item_line_range(),
2000            text_is_truncated,
2001            signature_range: snippet.declaration.signature_range_in_item_text(),
2002            parent_index,
2003            signature_score: snippet.score(DeclarationStyle::Signature),
2004            declaration_score: snippet.score(DeclarationStyle::Declaration),
2005            score_components: snippet.components,
2006        });
2007    }
2008
2009    let excerpt_parent = index_state.and_then(|index_state| {
2010        context
2011            .excerpt
2012            .parent_declarations
2013            .last()
2014            .and_then(|(parent, _)| {
2015                add_signature(
2016                    *parent,
2017                    &mut declaration_to_signature_index,
2018                    &mut signatures,
2019                    index_state,
2020                )
2021            })
2022    });
2023
2024    predict_edits_v3::PredictEditsRequest {
2025        excerpt_path,
2026        excerpt: context.excerpt_text.body,
2027        excerpt_line_range: context.excerpt.line_range,
2028        excerpt_range: context.excerpt.range,
2029        cursor_point: predict_edits_v3::Point {
2030            line: predict_edits_v3::Line(context.cursor_point.row),
2031            column: context.cursor_point.column,
2032        },
2033        referenced_declarations,
2034        included_files: vec![],
2035        signatures,
2036        excerpt_parent,
2037        events,
2038        can_collect_data,
2039        diagnostic_groups,
2040        diagnostic_groups_truncated,
2041        git_info,
2042        debug_info,
2043        prompt_max_bytes,
2044        prompt_format,
2045    }
2046}
2047
2048fn add_signature(
2049    declaration_id: DeclarationId,
2050    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2051    signatures: &mut Vec<Signature>,
2052    index: &SyntaxIndexState,
2053) -> Option<usize> {
2054    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2055        return Some(*signature_index);
2056    }
2057    let Some(parent_declaration) = index.declaration(declaration_id) else {
2058        log::error!("bug: missing parent declaration");
2059        return None;
2060    };
2061    let parent_index = parent_declaration.parent().and_then(|parent| {
2062        add_signature(parent, declaration_to_signature_index, signatures, index)
2063    });
2064    let (text, text_is_truncated) = parent_declaration.signature_text();
2065    let signature_index = signatures.len();
2066    signatures.push(Signature {
2067        text: text.into(),
2068        text_is_truncated,
2069        parent_index,
2070        range: parent_declaration.signature_line_range(),
2071    });
2072    declaration_to_signature_index.insert(declaration_id, signature_index);
2073    Some(signature_index)
2074}
2075
/// Cache key for recorded LLM interactions during evals: the request kind plus
/// a hash identifying the request input.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
2078
/// The kinds of requests that can be cached during evals; the `Display` impl
/// renders each as a lowercase tag ("context", "search", "prediction").
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    // Context-retrieval planning requests (see `refresh_context`).
    Context,
    Search,
    Prediction,
}
2086
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Renders the entry kind as its lowercase cache tag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{tag}")
    }
}
2097
/// Key/value store for recorded LLM request/response pairs, used by eval runs.
/// `read` returns the cached response for a key, if any; `write` records both
/// the request input and the response value under that key.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2103
2104#[cfg(test)]
2105mod tests {
2106    use std::{path::Path, sync::Arc};
2107
2108    use client::UserStore;
2109    use clock::FakeSystemClock;
2110    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2111    use futures::{
2112        AsyncReadExt, StreamExt,
2113        channel::{mpsc, oneshot},
2114    };
2115    use gpui::{
2116        Entity, TestAppContext,
2117        http_client::{FakeHttpClient, Response},
2118        prelude::*,
2119    };
2120    use indoc::indoc;
2121    use language::OffsetRangeExt as _;
2122    use open_ai::Usage;
2123    use pretty_assertions::{assert_eq, assert_matches};
2124    use project::{FakeFs, Project};
2125    use serde_json::json;
2126    use settings::SettingsStore;
2127    use util::path;
2128    use uuid::Uuid;
2129
2130    use crate::{BufferEditPrediction, Zeta};
2131
    /// End-to-end check of `current_prediction_for_buffer` state transitions:
    /// a prediction in the current buffer is `Local`, a prediction targeting a
    /// different file is `Jump` from the source buffer but `Local` once that
    /// file's buffer is opened. Also exercises a context-refresh round trip.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff editing the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Respond to the planning request with a single search-tool call.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer1's perspective the prediction is a jump to 2.txt.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt is open, the same prediction is local to that buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2276
    /// Happy path: a single prediction request whose diff response is parsed
    /// into one edit at the cursor position.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff's "-How/+How are you?" hunk becomes one insertion at (1, 3).
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2340
    /// Verifies that buffer edits made after `register_buffer` are included in
    /// the prediction request prompt as a unified diff of the recent change.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registering the buffer makes Zeta track its edits as events.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The edit above should appear in the prompt as a diff.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2414
    // Skipped until diagnostics are included in the prompt; re-enable the
    // `#[gpui::test]` attribute below once they are.
2416    // #[gpui::test]
2417    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2418    //     let (zeta, mut req_rx) = init_test(cx);
2419    //     let fs = FakeFs::new(cx.executor());
2420    //     fs.insert_tree(
2421    //         "/root",
2422    //         json!({
2423    //             "foo.md": "Hello!\nBye"
2424    //         }),
2425    //     )
2426    //     .await;
2427    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2428
2429    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2430    //     let diagnostic = lsp::Diagnostic {
2431    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2432    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
2433    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2434    //         ..Default::default()
2435    //     };
2436
2437    //     project.update(cx, |project, cx| {
2438    //         project.lsp_store().update(cx, |lsp_store, cx| {
2439    //             // Create some diagnostics
2440    //             lsp_store
2441    //                 .update_diagnostics(
2442    //                     LanguageServerId(0),
2443    //                     lsp::PublishDiagnosticsParams {
2444    //                         uri: path_to_buffer_uri.clone(),
2445    //                         diagnostics: vec![diagnostic],
2446    //                         version: None,
2447    //                     },
2448    //                     None,
2449    //                     language::DiagnosticSourceKind::Pushed,
2450    //                     &[],
2451    //                     cx,
2452    //                 )
2453    //                 .unwrap();
2454    //         });
2455    //     });
2456
2457    //     let buffer = project
2458    //         .update(cx, |project, cx| {
2459    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2460    //             project.open_buffer(path, cx)
2461    //         })
2462    //         .await
2463    //         .unwrap();
2464
2465    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2466    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
2467
2468    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
2469    //         zeta.request_prediction(&project, &buffer, position, cx)
2470    //     });
2471
2472    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
2473
2474    //     assert_eq!(request.diagnostic_groups.len(), 1);
2475    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2476    //         .unwrap();
2477    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2478    //     assert_eq!(
2479    //         value,
2480    //         json!({
2481    //             "entries": [{
2482    //                 "range": {
2483    //                     "start": 8,
2484    //                     "end": 10
2485    //                 },
2486    //                 "diagnostic": {
2487    //                     "source": null,
2488    //                     "code": null,
2489    //                     "code_description": null,
2490    //                     "severity": 1,
2491    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2492    //                     "markdown": null,
2493    //                     "group_id": 0,
2494    //                     "is_primary": true,
2495    //                     "is_disk_based": false,
2496    //                     "is_unnecessary": false,
2497    //                     "source_kind": "Pushed",
2498    //                     "data": null,
2499    //                     "underline": true
2500    //                 }
2501    //             }],
2502    //             "primary_ix": 0
2503    //         })
2504    //     );
2505    // }
2506
2507    fn model_response(text: &str) -> open_ai::Response {
2508        open_ai::Response {
2509            id: Uuid::new_v4().to_string(),
2510            object: "response".into(),
2511            created: 0,
2512            model: "model".into(),
2513            choices: vec![open_ai::Choice {
2514                index: 0,
2515                message: open_ai::RequestMessage::Assistant {
2516                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
2517                    tool_calls: vec![],
2518                },
2519                finish_reason: None,
2520            }],
2521            usage: Usage {
2522                prompt_tokens: 0,
2523                completion_tokens: 0,
2524                total_tokens: 0,
2525            },
2526        }
2527    }
2528
2529    fn prompt_from_request(request: &open_ai::Request) -> &str {
2530        assert_eq!(request.messages.len(), 1);
2531        let open_ai::RequestMessage::User {
2532            content: open_ai::MessageContent::Plain(content),
2533            ..
2534        } = &request.messages[0]
2535        else {
2536            panic!(
2537                "Request does not have single user message of type Plain. {:#?}",
2538                request
2539            );
2540        };
2541        content
2542    }
2543
    /// Creates a `Zeta` entity backed by a fake HTTP client, plus a channel
    /// that yields each outgoing prediction request together with a oneshot
    /// sender the test uses to supply the fake model response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            let (req_tx, req_rx) = mpsc::unbounded();

            // Fake server handling the two endpoints the client hits in tests.
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token refresh: always hand back a static token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction request: forward the deserialized body
                            // to the test via `req_tx`, then block until the
                            // test sends back a response on the oneshot.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2598}