zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, Point, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::{Project, ProjectPath};
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::ops::Range;
  35use std::path::Path;
  36use std::str::FromStr as _;
  37use std::sync::{Arc, LazyLock};
  38use std::time::{Duration, Instant};
  39use std::{env, mem};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, ResultExt as _, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod assemble_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49mod sweep_ai;
  50pub mod udiff;
  51mod xml_edits;
  52
  53use crate::assemble_excerpts::assemble_excerpts;
  54pub use crate::prediction::EditPrediction;
  55pub use crate::prediction::EditPredictionId;
  56pub use provider::ZetaEditPredictionProvider;
  57
/// Maximum number of buffer-change events retained when the Sweep backend is active.
const EVENT_COUNT_MAX_SWEEP: usize = 6;
/// Maximum number of buffer-change events retained when the Zed Cloud backend is active.
const EVENT_COUNT_MAX_ZETA: usize = 16;
/// Consecutive edits whose end rows are within this many lines of each other are
/// coalesced into a single `Event::BufferChange` (see `report_changes_for_buffer`).
const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
  62
  63pub struct SweepFeatureFlag;
  64
  65impl FeatureFlag for SweepFeatureFlag {
  66    const NAME: &str = "sweep-ai";
  67}
/// Default sizing for the excerpt around the cursor that is sent to the model.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    // Aim to place roughly half of the excerpt before the cursor.
    target_before_cursor_over_total_bytes: 0.5,
};

/// Context retrieval defaults to the agentic (search-driven) mode.
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);

/// Defaults for agentic context retrieval.
pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Defaults for syntax-index-based context retrieval.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Default [`ZetaOptions`] used when a `Zeta` is first created.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
    buffer_change_grouping_interval: Duration::from_secs(1),
};
  99
/// When `ZED_ZETA2_OLLAMA` is set to a non-empty value, requests are routed to
/// a local Ollama server instead of the hosted models.
static USE_OLLAMA: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
/// Model id used for the context-retrieval step; overridable via
/// `ZED_ZETA2_CONTEXT_MODEL`.
static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
        "qwen3-coder:30b".to_string()
    } else {
        "yqvev8r3".to_string()
    })
});
/// Model id used for edit prediction; overridable via `ZED_ZETA2_MODEL`.
static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
    match env::var("ZED_ZETA2_MODEL").as_deref() {
        Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
        Ok(model) => model,
        Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
        Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
    }
    .to_string()
});
/// Prediction endpoint override: `ZED_PREDICT_EDITS_URL` wins, otherwise the
/// local Ollama chat-completions endpoint when `USE_OLLAMA` is set.
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
        if *USE_OLLAMA {
            Some("http://localhost:11434/v1/chat/completions".into())
        } else {
            None
        }
    })
});
 127
/// Feature flag gating the zeta2 edit-prediction engine.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Deliberately opt-in even for staff accounts.
    fn enabled_for_staff() -> bool {
        false
    }
}
 137
/// Global handle to the single app-wide [`Zeta`] entity.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}

/// Edit-prediction engine: tracks per-project edit history and retrieved
/// context, and requests predictions from the configured backend model.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Cached LLM API token, refreshed via `_llm_token_subscription`.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by `Project` entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): presumably set when the server reports the client version
    // is too old (see MINIMUM_REQUIRED_VERSION_HEADER_NAME) — confirm at call sites.
    update_required: bool,
    // When present, debug events are streamed to this channel (see `debug_info`).
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    #[cfg(feature = "eval-support")]
    eval_cache: Option<Arc<dyn EvalCache>>,
    edit_prediction_model: ZetaEditPredictionModel,
    // Read from the `SWEEP_AI_TOKEN` environment variable; `None` disables Sweep.
    sweep_api_token: Option<String>,
    sweep_ai_debug_info: Arc<str>,
}

/// Which backend serves edit predictions.
#[derive(PartialEq, Eq)]
pub enum ZetaEditPredictionModel {
    ZedCloud,
    Sweep,
}
 164
/// Tuning knobs for prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: ContextMode,
    pub max_prompt_bytes: usize,
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    pub file_indexing_parallelism: usize,
    // NOTE(review): presumably groups rapid consecutive edits into one event —
    // confirm where this interval is consumed.
    pub buffer_change_grouping_interval: Duration,
}

/// How context for a prediction is gathered.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Search-driven retrieval.
    Agentic(AgenticContextOptions),
    /// Syntax-index-based retrieval (see `SyntaxIndex`).
    Syntax(EditPredictionContextOptions),
}

/// Options specific to agentic context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct AgenticContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}
 185
 186impl ContextMode {
 187    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 188        match self {
 189            ContextMode::Agentic(options) => &options.excerpt,
 190            ContextMode::Syntax(options) => &options.excerpt,
 191        }
 192    }
 193}
 194
/// Events streamed to debug listeners (see [`Zeta::debug_info`]).
#[derive(Debug)]
pub enum ZetaDebugInfo {
    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
    EditPredictionRequested(ZetaEditPredictionDebugInfo),
}

/// Emitted when context retrieval begins, carrying the search prompt.
#[derive(Debug)]
pub struct ZetaContextRetrievalStartedDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_prompt: String,
}

/// Timing marker for a context-retrieval phase.
#[derive(Debug)]
pub struct ZetaContextRetrievalDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
}

/// Snapshot of an edit-prediction request, including a channel that later
/// yields the model response together with its latency.
#[derive(Debug)]
pub struct ZetaEditPredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // Locally-assembled prompt, or an error message if assembly failed.
    pub local_prompt: Result<String, String>,
    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}

/// The search queries generated during context retrieval.
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
    pub project: Entity<Project>,
    pub timestamp: Instant,
    pub search_queries: Vec<SearchToolQuery>,
}

pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 235
/// Per-project state tracked by [`Zeta`].
struct ZetaProject {
    // Present only when the options use `ContextMode::Syntax`.
    syntax_index: Option<Entity<SyntaxIndex>>,
    // Recent buffer-change events, oldest first, capped by `EVENT_COUNT_MAX_*`.
    events: VecDeque<Event>,
    // Recently-active project paths, most recent first.
    recent_paths: VecDeque<ProjectPath>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
    // Context ranges gathered per buffer; `None` until retrieval has run.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
    // Keeps the `project::Event` subscription alive for this project.
    _subscription: gpui::Subscription,
}

/// The prediction currently offered to the user, plus which buffer requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 254
impl CurrentEditPrediction {
    /// Whether `self` (a freshly-requested prediction) should replace
    /// `old_prediction` as the project's current prediction.
    ///
    /// A new prediction that no longer interpolates against its buffer is
    /// never installed; an old one that no longer interpolates is always
    /// replaced.
    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
        let Some(new_edits) = self
            .prediction
            .interpolate(&self.prediction.buffer.read(cx))
        else {
            return false;
        };

        if self.prediction.buffer != old_prediction.prediction.buffer {
            return true;
        }

        let Some(old_edits) = old_prediction
            .prediction
            .interpolate(&old_prediction.prediction.buffer.read(cx))
        else {
            return true;
        };

        // This reduces the occurrence of UI thrash from replacing edits
        //
        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
            && old_edits.len() == 1
            && new_edits.len() == 1
        {
            // Replace only when the new single edit extends the old one in place.
            let (old_range, old_text) = &old_edits[0];
            let (new_range, new_text) = &new_edits[0];
            new_range == old_range && new_text.starts_with(old_text.as_ref())
        } else {
            true
        }
    }
}
 291
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer; this buffer offers a jump to it.
    Jump { prediction: &'a EditPrediction },
}

/// Bookkeeping for a buffer whose edits feed the event history.
struct RegisteredBuffer {
    // Last snapshot reported; edits are diffed against it on each change.
    snapshot: BufferSnapshot,
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        // End of the last edit in this change, used for coalescing nearby edits.
        end_edit_anchor: Option<Anchor>,
        timestamp: Instant,
    },
}
 313
 314impl Event {
 315    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 316        match self {
 317            Event::BufferChange {
 318                old_snapshot,
 319                new_snapshot,
 320                ..
 321            } => {
 322                let path = new_snapshot.file().map(|f| f.full_path(cx));
 323
 324                let old_path = old_snapshot.file().and_then(|f| {
 325                    let old_path = f.full_path(cx);
 326                    if Some(&old_path) != path.as_ref() {
 327                        Some(old_path)
 328                    } else {
 329                        None
 330                    }
 331                });
 332
 333                // TODO [zeta2] move to bg?
 334                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 335
 336                if path == old_path && diff.is_empty() {
 337                    None
 338                } else {
 339                    Some(predict_edits_v3::Event::BufferChange {
 340                        old_path,
 341                        path,
 342                        diff,
 343                        //todo: Actually detect if this edit was predicted or not
 344                        predicted: false,
 345                    })
 346                }
 347            }
 348        }
 349    }
 350
 351    pub fn project_path(&self, cx: &App) -> Option<project::ProjectPath> {
 352        match self {
 353            Event::BufferChange { new_snapshot, .. } => new_snapshot
 354                .file()
 355                .map(|f| project::ProjectPath::from_file(f.as_ref(), cx)),
 356        }
 357    }
 358}
 359
 360impl Zeta {
 361    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 362        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 363    }
 364
 365    pub fn global(
 366        client: &Arc<Client>,
 367        user_store: &Entity<UserStore>,
 368        cx: &mut App,
 369    ) -> Entity<Self> {
 370        cx.try_global::<ZetaGlobal>()
 371            .map(|global| global.0.clone())
 372            .unwrap_or_else(|| {
 373                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 374                cx.set_global(ZetaGlobal(zeta.clone()));
 375                zeta
 376            })
 377    }
 378
    /// Creates a new `Zeta`. Use [`Zeta::global`] to obtain the shared
    /// instance instead of constructing one directly.
    ///
    /// Subscribes to LLM-token refresh events so the cached token stays
    /// current, and reads `SWEEP_AI_TOKEN` for the optional Sweep backend.
    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::default(),
            client,
            user_store,
            options: DEFAULT_OPTIONS,
            llm_token: LlmApiToken::default(),
            // Refresh the token whenever the listener reports a change.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
            debug_tx: None,
            #[cfg(feature = "eval-support")]
            eval_cache: None,
            edit_prediction_model: ZetaEditPredictionModel::ZedCloud,
            // A missing token only disables Sweep; log the absence and continue.
            sweep_api_token: std::env::var("SWEEP_AI_TOKEN")
                .context("No SWEEP_AI_TOKEN environment variable set")
                .log_err(),
            sweep_ai_debug_info: sweep_ai::debug_info(cx),
        }
    }
 411
    /// Selects which backend serves edit predictions.
    pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
        self.edit_prediction_model = model;
    }

    /// Whether a `SWEEP_AI_TOKEN` was found in the environment at startup.
    pub fn has_sweep_api_token(&self) -> bool {
        self.sweep_api_token.is_some()
    }

    /// Installs a cache used by evals.
    #[cfg(feature = "eval-support")]
    pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
        self.eval_cache = Some(cache);
    }

    /// Starts streaming debug events; any previously-returned receiver stops
    /// receiving because its sender is replaced.
    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
        self.debug_tx = Some(debug_watch_tx);
        debug_watch_rx
    }

    /// The current prediction/context options.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }

    /// Replaces the prediction/context options.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }

    /// Drops the recorded buffer-change events for every project.
    pub fn clear_history(&mut self) {
        for zeta_project in self.projects.values_mut() {
            zeta_project.events.clear();
        }
    }
 444
 445    pub fn history_for_project(
 446        &self,
 447        project: &Entity<Project>,
 448    ) -> impl DoubleEndedIterator<Item = &Event> {
 449        self.projects
 450            .get(&project.entity_id())
 451            .map(|project| project.events.iter())
 452            .into_iter()
 453            .flatten()
 454    }
 455
 456    pub fn context_for_project(
 457        &self,
 458        project: &Entity<Project>,
 459    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 460        self.projects
 461            .get(&project.entity_id())
 462            .and_then(|project| {
 463                Some(
 464                    project
 465                        .context
 466                        .as_ref()?
 467                        .iter()
 468                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 469                )
 470            })
 471            .into_iter()
 472            .flatten()
 473    }
 474
 475    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 476        if self.edit_prediction_model == ZetaEditPredictionModel::ZedCloud {
 477            self.user_store.read(cx).edit_prediction_usage()
 478        } else {
 479            None
 480        }
 481    }
 482
    /// Ensures per-project state exists for `project` so its events are tracked.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking `buffer` (and its project) so edits feed the event
    /// history used for predictions.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }
 496
    /// Returns the state for `project`, creating it on first use.
    ///
    /// Creation builds a syntax index only in `ContextMode::Syntax`, and
    /// subscribes to the project to maintain its recently-active path list.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
                    Some(cx.new(|cx| {
                        SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
                    }))
                } else {
                    None
                },
                events: VecDeque::new(),
                recent_paths: VecDeque::new(),
                registered_buffers: HashMap::default(),
                current_prediction: None,
                context: None,
                refresh_context_task: None,
                refresh_context_debounce_task: None,
                refresh_context_timestamp: None,
                _subscription: cx.subscribe(&project, |this, project, event, cx| {
                    // TODO [zeta2] init with recent paths
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        if let project::Event::ActiveEntryChanged(Some(active_entry_id)) = event {
                            let path = project.read(cx).path_for_entry(*active_entry_id, cx);
                            if let Some(path) = path {
                                // Move the newly-active path to the front,
                                // removing any earlier occurrence first.
                                if let Some(ix) = zeta_project
                                    .recent_paths
                                    .iter()
                                    .position(|probe| probe == &path)
                                {
                                    zeta_project.recent_paths.remove(ix);
                                }
                                zeta_project.recent_paths.push_front(path);
                            }
                        }
                    }
                }),
            })
    }
 540
    /// Returns the [`RegisteredBuffer`] for `buffer`, subscribing to its edit
    /// and release events on first registration.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        // Feed each edit into the project's event history.
                        cx.subscribe(buffer, {
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        // Drop the registration when the buffer is released.
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 578
    /// Records an edit to `buffer` in the project's event history.
    ///
    /// Consecutive edits to the same buffer whose end rows lie within
    /// `CHANGE_GROUPING_LINE_SPAN` lines are coalesced into the previous
    /// event; history length is capped per active model (`EVENT_COUNT_MAX_*`).
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let event_count_max = match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => EVENT_COUNT_MAX_ZETA,
            ZetaEditPredictionModel::Sweep => EVENT_COUNT_MAX_SWEEP,
        };

        // NOTE(review): despite the name, this is the generic per-project
        // state, not Sweep-specific — consider renaming to `zeta_project`.
        let sweep_ai_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(sweep_ai_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        // Nothing to record if the buffer hasn't changed since the last report.
        if new_snapshot.version == registered_buffer.snapshot.version {
            return;
        }

        let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
        // Anchor at the end of the last edit, used for proximity-based grouping.
        let end_edit_anchor = new_snapshot
            .anchored_edits_since::<Point>(&old_snapshot.version)
            .last()
            .map(|(_, range)| range.end);
        let events = &mut sweep_ai_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            end_edit_anchor: last_end_edit_anchor,
            ..
        }) = events.back_mut()
        {
            // Coalesce only when this change directly continues the last
            // event's snapshot of the same buffer...
            let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
                == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version;

            // ...and both changes end within a few lines of each other.
            let should_coalesce = is_next_snapshot_of_same_buffer
                && end_edit_anchor
                    .as_ref()
                    .zip(last_end_edit_anchor.as_ref())
                    .is_some_and(|(a, b)| {
                        let a = a.to_point(&new_snapshot);
                        let b = b.to_point(&new_snapshot);
                        a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
                    });

            if should_coalesce {
                *last_end_edit_anchor = end_edit_anchor;
                *last_new_snapshot = new_snapshot;
                return;
            }
        }

        // Evict the oldest event to stay within the cap.
        if events.len() >= event_count_max {
            events.pop_front();
        }

        events.push_back(Event::BufferChange {
            old_snapshot,
            new_snapshot,
            end_edit_anchor,
            timestamp: Instant::now(),
        });
    }
 643
 644    fn current_prediction_for_buffer(
 645        &self,
 646        buffer: &Entity<Buffer>,
 647        project: &Entity<Project>,
 648        cx: &App,
 649    ) -> Option<BufferEditPrediction<'_>> {
 650        let project_state = self.projects.get(&project.entity_id())?;
 651
 652        let CurrentEditPrediction {
 653            requested_by_buffer_id,
 654            prediction,
 655        } = project_state.current_prediction.as_ref()?;
 656
 657        if prediction.targets_buffer(buffer.read(cx)) {
 658            Some(BufferEditPrediction::Local { prediction })
 659        } else if *requested_by_buffer_id == buffer.entity_id() {
 660            Some(BufferEditPrediction::Jump { prediction })
 661        } else {
 662            None
 663        }
 664    }
 665
    /// Reports acceptance of the current prediction to the Zed Cloud API and
    /// clears it. No-op for other backends or when nothing is pending.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        if self.edit_prediction_model != ZetaEditPredictionModel::ZedCloud {
            return;
        }

        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        let request_id = prediction.prediction.id.to_string();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the endpoint (e.g. for testing).
            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody {
                                request_id: request_id.clone(),
                            })?
                            .into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 714
 715    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 716        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 717            project_state.current_prediction.take();
 718        };
 719    }
 720
 721    pub fn refresh_prediction(
 722        &mut self,
 723        project: &Entity<Project>,
 724        buffer: &Entity<Buffer>,
 725        position: language::Anchor,
 726        cx: &mut Context<Self>,
 727    ) -> Task<Result<()>> {
 728        let request_task = self.request_prediction(project, buffer, position, cx);
 729        let buffer = buffer.clone();
 730        let project = project.clone();
 731
 732        cx.spawn(async move |this, cx| {
 733            if let Some(prediction) = request_task.await? {
 734                this.update(cx, |this, cx| {
 735                    let project_state = this
 736                        .projects
 737                        .get_mut(&project.entity_id())
 738                        .context("Project not found")?;
 739
 740                    let new_prediction = CurrentEditPrediction {
 741                        requested_by_buffer_id: buffer.entity_id(),
 742                        prediction: prediction,
 743                    };
 744
 745                    if project_state
 746                        .current_prediction
 747                        .as_ref()
 748                        .is_none_or(|old_prediction| {
 749                            new_prediction.should_replace_prediction(&old_prediction, cx)
 750                        })
 751                    {
 752                        project_state.current_prediction = Some(new_prediction);
 753                    }
 754                    anyhow::Ok(())
 755                })??;
 756            }
 757            Ok(())
 758        })
 759    }
 760
    /// Requests an edit prediction for `position` in `active_buffer`,
    /// dispatching to the configured backend. Resolves to `None` when the
    /// backend produces no prediction.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        match self.edit_prediction_model {
            ZetaEditPredictionModel::ZedCloud => {
                self.request_prediction_with_zed_cloud(project, active_buffer, position, cx)
            }
            ZetaEditPredictionModel::Sweep => {
                self.request_prediction_with_sweep(project, active_buffer, position, cx)
            }
        }
    }
 777
 778    fn request_prediction_with_sweep(
 779        &mut self,
 780        project: &Entity<Project>,
 781        active_buffer: &Entity<Buffer>,
 782        position: language::Anchor,
 783        cx: &mut Context<Self>,
 784    ) -> Task<Result<Option<EditPrediction>>> {
 785        let snapshot = active_buffer.read(cx).snapshot();
 786        let debug_info = self.sweep_ai_debug_info.clone();
 787        let Some(api_token) = self.sweep_api_token.clone() else {
 788            return Task::ready(Ok(None));
 789        };
 790        let full_path: Arc<Path> = snapshot
 791            .file()
 792            .map(|file| file.full_path(cx))
 793            .unwrap_or_else(|| "untitled".into())
 794            .into();
 795
 796        let project_file = project::File::from_dyn(snapshot.file());
 797        let repo_name = project_file
 798            .map(|file| file.worktree.read(cx).root_name_str())
 799            .unwrap_or("untitled")
 800            .into();
 801        let offset = position.to_offset(&snapshot);
 802
 803        let project_state = self.get_or_init_zeta_project(project, cx);
 804        let events = project_state.events.clone();
 805        let recent_buffers = project_state.recent_paths.iter().cloned();
 806        let http_client = cx.http_client();
 807
 808        let recent_buffer_snapshots = recent_buffers
 809            .filter_map(|project_path| {
 810                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
 811                if active_buffer == &buffer {
 812                    None
 813                } else {
 814                    Some(buffer.read(cx).snapshot())
 815                }
 816            })
 817            .take(3)
 818            .collect::<Vec<_>>();
 819
 820        let result = cx.background_spawn(async move {
 821            let text = snapshot.text();
 822
 823            let mut recent_changes = String::new();
 824            for event in events {
 825                sweep_ai::write_event(event, &mut recent_changes).unwrap();
 826            }
 827
 828            let file_chunks = recent_buffer_snapshots
 829                .into_iter()
 830                .map(|snapshot| {
 831                    let end_point = language::Point::new(30, 0).min(snapshot.max_point());
 832                    sweep_ai::FileChunk {
 833                        content: snapshot
 834                            .text_for_range(language::Point::zero()..end_point)
 835                            .collect(),
 836                        file_path: snapshot
 837                            .file()
 838                            .map(|f| f.path().as_unix_str())
 839                            .unwrap_or("untitled")
 840                            .to_string(),
 841                        start_line: 0,
 842                        end_line: end_point.row as usize,
 843                        timestamp: snapshot.file().and_then(|file| {
 844                            Some(
 845                                file.disk_state()
 846                                    .mtime()?
 847                                    .to_seconds_and_nanos_for_persistence()?
 848                                    .0,
 849                            )
 850                        }),
 851                    }
 852                })
 853                .collect();
 854
 855            let request_body = sweep_ai::AutocompleteRequest {
 856                debug_info,
 857                repo_name,
 858                file_path: full_path.clone(),
 859                file_contents: text.clone(),
 860                original_file_contents: text,
 861                cursor_position: offset,
 862                recent_changes: recent_changes.clone(),
 863                changes_above_cursor: true,
 864                multiple_suggestions: false,
 865                branch: None,
 866                file_chunks,
 867                retrieval_chunks: vec![],
 868                recent_user_actions: vec![],
 869                // TODO
 870                privacy_mode_enabled: false,
 871            };
 872
 873            let mut buf: Vec<u8> = Vec::new();
 874            let writer = brotli::CompressorWriter::new(&mut buf, 4096, 11, 22);
 875            serde_json::to_writer(writer, &request_body)?;
 876            let body: AsyncBody = buf.into();
 877
 878            const SWEEP_API_URL: &str =
 879                "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 880
 881            let request = http_client::Request::builder()
 882                .uri(SWEEP_API_URL)
 883                .header("Content-Type", "application/json")
 884                .header("Authorization", format!("Bearer {}", api_token))
 885                .header("Connection", "keep-alive")
 886                .header("Content-Encoding", "br")
 887                .method(Method::POST)
 888                .body(body)?;
 889
 890            let mut response = http_client.send(request).await?;
 891
 892            let mut body: Vec<u8> = Vec::new();
 893            response.body_mut().read_to_end(&mut body).await?;
 894
 895            if !response.status().is_success() {
 896                anyhow::bail!(
 897                    "Request failed with status: {:?}\nBody: {}",
 898                    response.status(),
 899                    String::from_utf8_lossy(&body),
 900                );
 901            };
 902
 903            let response: sweep_ai::AutocompleteResponse = serde_json::from_slice(&body)?;
 904
 905            let old_text = snapshot
 906                .text_for_range(response.start_index..response.end_index)
 907                .collect::<String>();
 908            let edits = language::text_diff(&old_text, &response.completion)
 909                .into_iter()
 910                .map(|(range, text)| {
 911                    (
 912                        snapshot.anchor_after(response.start_index + range.start)
 913                            ..snapshot.anchor_before(response.start_index + range.end),
 914                        text,
 915                    )
 916                })
 917                .collect::<Vec<_>>();
 918
 919            anyhow::Ok((response.autocomplete_id, edits, snapshot))
 920        });
 921
 922        let buffer = active_buffer.clone();
 923
 924        cx.spawn(async move |_, cx| {
 925            let (id, edits, old_snapshot) = result.await?;
 926
 927            if edits.is_empty() {
 928                return anyhow::Ok(None);
 929            }
 930
 931            let Some((edits, new_snapshot, preview_task)) =
 932                buffer.read_with(cx, |buffer, cx| {
 933                    let new_snapshot = buffer.snapshot();
 934
 935                    let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
 936                        edit_prediction::interpolate_edits(&old_snapshot, &new_snapshot, &edits)?
 937                            .into();
 938                    let preview_task = buffer.preview_edits(edits.clone(), cx);
 939
 940                    Some((edits, new_snapshot, preview_task))
 941                })?
 942            else {
 943                return anyhow::Ok(None);
 944            };
 945
 946            let prediction = EditPrediction {
 947                id: EditPredictionId(id.into()),
 948                edits,
 949                snapshot: new_snapshot,
 950                edit_preview: preview_task.await,
 951                buffer,
 952            };
 953
 954            anyhow::Ok(Some(prediction))
 955        })
 956    }
 957
    /// Requests an edit prediction from the Zed cloud endpoint for the cursor
    /// `position` in `active_buffer`.
    ///
    /// Context gathering (buffer snapshots, edit-history events, diagnostics,
    /// and any previously retrieved context files) happens on the foreground;
    /// prompt construction and the HTTP round trip run on a background task.
    /// The model's textual response is parsed back into anchored edits against
    /// one of the included buffers.
    ///
    /// Resolves to `Ok(None)` when no excerpt/context could be gathered or the
    /// model returned no usable text.
    fn request_prediction_with_zed_cloud(
        &mut self,
        project: &Entity<Project>,
        active_buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        // Snapshot of the syntax index, if one has been built for this project.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let active_snapshot = active_buffer.read(cx).snapshot();
        let Some(excerpt_path) = active_snapshot
            .file()
            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
        else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();
        let debug_tx = self.debug_tx.clone();

        // Recent edit-history events, converted to the request representation.
        let events = project_state
            .map(|state| {
                state
                    .events
                    .iter()
                    .filter_map(|event| event.to_request_event(cx))
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();

        let diagnostics = active_snapshot.diagnostic_sets().clone();

        // Absolute path of the active file's parent directory, when one exists.
        let parent_abs_path =
            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
                let mut path = f.worktree.read(cx).absolutize(&f.path);
                if path.pop() { Some(path) } else { None }
            });

        // TODO data collection
        let can_collect_data = cx.is_staff();

        let empty_context_files = HashMap::default();
        let context_files = project_state
            .and_then(|project_state| project_state.context.as_ref())
            .unwrap_or(&empty_context_files);

        // For evals, wait until context buffers finish parsing so that results
        // are deterministic.
        #[cfg(feature = "eval-support")]
        let parsed_fut = futures::future::join_all(
            context_files
                .keys()
                .map(|buffer| buffer.read(cx).parsing_idle()),
        );

        // (buffer, snapshot, full path, anchor ranges) for every context file
        // that still has a backing file.
        let mut included_files = context_files
            .iter()
            .filter_map(|(buffer_entity, ranges)| {
                let buffer = buffer_entity.read(cx);
                Some((
                    buffer_entity.clone(),
                    buffer.snapshot(),
                    buffer.file()?.full_path(cx).into(),
                    ranges.clone(),
                ))
            })
            .collect::<Vec<_>>();

        // Stable ordering keeps the prompt deterministic across requests.
        included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
            (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
        });

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        let request_task = cx.background_spawn({
            let active_buffer = active_buffer.clone();
            async move {
                #[cfg(feature = "eval-support")]
                parsed_fut.await;

                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_offset = position.to_offset(&active_snapshot);
                let cursor_point = cursor_offset.to_point(&active_snapshot);

                let before_retrieval = chrono::Utc::now();

                let (diagnostic_groups, diagnostic_groups_truncated) =
                    Self::gather_nearby_diagnostics(
                        cursor_offset,
                        &diagnostics,
                        &active_snapshot,
                        options.max_diagnostic_bytes,
                    );

                // Build the wire request according to the configured context mode.
                let cloud_request = match options.context {
                    ContextMode::Agentic(context_options) => {
                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                            cursor_point,
                            &active_snapshot,
                            &context_options.excerpt,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                            ..active_snapshot.anchor_before(excerpt.range.end);

                        // Merge the cursor excerpt into the active buffer's
                        // existing entry (moving that entry to the end), or
                        // append a fresh entry for the active buffer.
                        if let Some(buffer_ix) =
                            included_files.iter().position(|(_, snapshot, _, _)| {
                                snapshot.remote_id() == active_snapshot.remote_id()
                            })
                        {
                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
                            ranges.push(excerpt_anchor_range);
                            retrieval_search::merge_anchor_ranges(ranges, buffer);
                            let last_ix = included_files.len() - 1;
                            included_files.swap(buffer_ix, last_ix);
                        } else {
                            included_files.push((
                                active_buffer.clone(),
                                active_snapshot.clone(),
                                excerpt_path.clone(),
                                vec![excerpt_anchor_range],
                            ));
                        }

                        // Convert anchor ranges into line-based excerpts for the
                        // wire format.
                        let included_files = included_files
                            .iter()
                            .map(|(_, snapshot, path, ranges)| {
                                let ranges = ranges
                                    .iter()
                                    .map(|range| {
                                        let point_range = range.to_point(&snapshot);
                                        Line(point_range.start.row)..Line(point_range.end.row)
                                    })
                                    .collect::<Vec<_>>();
                                let excerpts = assemble_excerpts(&snapshot, ranges);
                                predict_edits_v3::IncludedFile {
                                    path: path.clone(),
                                    max_row: Line(snapshot.max_point().row),
                                    excerpts,
                                }
                            })
                            .collect::<Vec<_>>();

                        predict_edits_v3::PredictEditsRequest {
                            excerpt_path,
                            excerpt: String::new(),
                            excerpt_line_range: Line(0)..Line(0),
                            excerpt_range: 0..0,
                            cursor_point: predict_edits_v3::Point {
                                line: predict_edits_v3::Line(cursor_point.row),
                                column: cursor_point.column,
                            },
                            included_files,
                            referenced_declarations: vec![],
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            debug_info: debug_tx.is_some(),
                            prompt_max_bytes: Some(options.max_prompt_bytes),
                            prompt_format: options.prompt_format,
                            // TODO [zeta2]
                            signatures: vec![],
                            excerpt_parent: None,
                            git_info: None,
                        }
                    }
                    ContextMode::Syntax(context_options) => {
                        let Some(context) = EditPredictionContext::gather_context(
                            cursor_point,
                            &active_snapshot,
                            parent_abs_path.as_deref(),
                            &context_options,
                            index_state.as_deref(),
                        ) else {
                            return Ok((None, None));
                        };

                        make_syntax_context_cloud_request(
                            excerpt_path,
                            context,
                            events,
                            can_collect_data,
                            diagnostic_groups,
                            diagnostic_groups_truncated,
                            None,
                            debug_tx.is_some(),
                            &worktree_snapshots,
                            index_state.as_deref(),
                            Some(options.max_prompt_bytes),
                            options.prompt_format,
                        )
                    }
                };

                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

                let retrieval_time = chrono::Utc::now() - before_retrieval;

                // Notify debug listeners of the outgoing request, and keep a
                // channel for forwarding the response to them later.
                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                    let (response_tx, response_rx) = oneshot::channel();

                    debug_tx
                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
                            ZetaEditPredictionDebugInfo {
                                request: cloud_request.clone(),
                                retrieval_time,
                                buffer: active_buffer.downgrade(),
                                local_prompt: match prompt_result.as_ref() {
                                    Ok((prompt, _)) => Ok(prompt.clone()),
                                    Err(err) => Err(err.to_string()),
                                },
                                position,
                                response_rx,
                            },
                        ))
                        .ok();
                    Some(response_tx)
                } else {
                    None
                };

                // Local-debugging escape hatch: skip the network request entirely.
                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                    if let Some(debug_response_tx) = debug_response_tx {
                        debug_response_tx
                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
                            .ok();
                    }
                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
                }

                let (prompt, _) = prompt_result?;
                let request = open_ai::Request {
                    model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                    messages: vec![open_ai::RequestMessage::User {
                        content: open_ai::MessageContent::Plain(prompt),
                    }],
                    stream: false,
                    max_completion_tokens: None,
                    stop: Default::default(),
                    temperature: 0.7,
                    tool_choice: None,
                    parallel_tool_calls: None,
                    tools: vec![],
                    prompt_cache_key: None,
                    reasoning_effort: None,
                };

                log::trace!("Sending edit prediction request");

                let before_request = chrono::Utc::now();
                let response = Self::send_raw_llm_request(
                    request,
                    client,
                    llm_token,
                    app_version,
                    #[cfg(feature = "eval-support")]
                    eval_cache,
                    #[cfg(feature = "eval-support")]
                    EvalCacheEntryKind::Prediction,
                )
                .await;
                let request_time = chrono::Utc::now() - before_request;

                log::trace!("Got edit prediction response");

                // Forward the raw response (or error text) to debug listeners.
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((
                            response
                                .as_ref()
                                .map_err(|err| err.to_string())
                                .map(|response| response.0.clone()),
                            request_time,
                        ))
                        .ok();
                }

                let (res, usage) = response?;
                let request_id = EditPredictionId(res.id.clone().into());
                let Some(mut output_text) = text_from_response(res) else {
                    return Ok((None, usage));
                };

                // The model may echo the cursor marker back; it must not leak
                // into the applied edits.
                if output_text.contains(CURSOR_MARKER) {
                    log::trace!("Stripping out {CURSOR_MARKER} from response");
                    output_text = output_text.replace(CURSOR_MARKER, "");
                }

                // Resolves a path from the model output back to an included buffer.
                let get_buffer_from_context = |path: &Path| {
                    included_files
                        .iter()
                        .find_map(|(_, buffer, probe_path, ranges)| {
                            if probe_path.as_ref() == path {
                                Some((buffer, ranges.as_slice()))
                            } else {
                                None
                            }
                        })
                };

                // Parse the output into (snapshot, edits) according to the
                // prompt format that was requested.
                let (edited_buffer_snapshot, edits) = match options.prompt_format {
                    PromptFormat::NumLinesUniDiff => {
                        // TODO: Implement parsing of multi-file diffs
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                    PromptFormat::Minimal | PromptFormat::MinimalQwen => {
                        if output_text.contains("--- a/\n+++ b/\nNo edits") {
                            let edits = vec![];
                            (&active_snapshot, edits)
                        } else {
                            crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                        }
                    }
                    PromptFormat::OldTextNewText => {
                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
                            .await?
                    }
                    _ => {
                        bail!("unsupported prompt format {}", options.prompt_format)
                    }
                };

                // Map the edited snapshot back to its buffer entity.
                let edited_buffer = included_files
                    .iter()
                    .find_map(|(buffer, snapshot, _, _)| {
                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
                            Some(buffer.clone())
                        } else {
                            None
                        }
                    })
                    .context("Failed to find buffer in included_buffers")?;

                anyhow::Ok((
                    Some((
                        request_id,
                        edited_buffer,
                        edited_buffer_snapshot.clone(),
                        edits,
                    )),
                    usage,
                ))
            }
        });

        cx.spawn({
            async move |this, cx| {
                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
                    Self::handle_api_response(&this, request_task.await, cx)?
                else {
                    return Ok(None);
                };

                // TODO telemetry: duration, etc
                Ok(
                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
                        .await,
                )
            }
        })
    }
1340
1341    async fn send_raw_llm_request(
1342        request: open_ai::Request,
1343        client: Arc<Client>,
1344        llm_token: LlmApiToken,
1345        app_version: SemanticVersion,
1346        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
1347        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
1348    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1349        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1350            http_client::Url::parse(&predict_edits_url)?
1351        } else {
1352            client
1353                .http_client()
1354                .build_zed_llm_url("/predict_edits/raw", &[])?
1355        };
1356
1357        #[cfg(feature = "eval-support")]
1358        let cache_key = if let Some(cache) = eval_cache {
1359            use collections::FxHasher;
1360            use std::hash::{Hash, Hasher};
1361
1362            let mut hasher = FxHasher::default();
1363            url.hash(&mut hasher);
1364            let request_str = serde_json::to_string_pretty(&request)?;
1365            request_str.hash(&mut hasher);
1366            let hash = hasher.finish();
1367
1368            let key = (eval_cache_kind, hash);
1369            if let Some(response_str) = cache.read(key) {
1370                return Ok((serde_json::from_str(&response_str)?, None));
1371            }
1372
1373            Some((cache, request_str, key))
1374        } else {
1375            None
1376        };
1377
1378        let (response, usage) = Self::send_api_request(
1379            |builder| {
1380                let req = builder
1381                    .uri(url.as_ref())
1382                    .body(serde_json::to_string(&request)?.into());
1383                Ok(req?)
1384            },
1385            client,
1386            llm_token,
1387            app_version,
1388        )
1389        .await?;
1390
1391        #[cfg(feature = "eval-support")]
1392        if let Some((cache, request, key)) = cache_key {
1393            cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
1394        }
1395
1396        Ok((response, usage))
1397    }
1398
1399    fn handle_api_response<T>(
1400        this: &WeakEntity<Self>,
1401        response: Result<(T, Option<EditPredictionUsage>)>,
1402        cx: &mut gpui::AsyncApp,
1403    ) -> Result<T> {
1404        match response {
1405            Ok((data, usage)) => {
1406                if let Some(usage) = usage {
1407                    this.update(cx, |this, cx| {
1408                        this.user_store.update(cx, |user_store, cx| {
1409                            user_store.update_edit_prediction_usage(usage, cx);
1410                        });
1411                    })
1412                    .ok();
1413                }
1414                Ok(data)
1415            }
1416            Err(err) => {
1417                if err.is::<ZedUpdateRequiredError>() {
1418                    cx.update(|cx| {
1419                        this.update(cx, |this, _cx| {
1420                            this.update_required = true;
1421                        })
1422                        .ok();
1423
1424                        let error_message: SharedString = err.to_string().into();
1425                        show_app_notification(
1426                            NotificationId::unique::<ZedUpdateRequiredError>(),
1427                            cx,
1428                            move |cx| {
1429                                cx.new(|cx| {
1430                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1431                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1432                                })
1433                            },
1434                        );
1435                    })
1436                    .ok();
1437                }
1438                Err(err)
1439            }
1440        }
1441    }
1442
    /// POSTs a request to the LLM backend and deserializes the JSON response.
    ///
    /// `build` completes a request builder that already carries the POST
    /// method, JSON content type, bearer token, and Zed version header.
    ///
    /// Retries exactly once with a freshly refreshed token when the server
    /// responds with the expired-LLM-token header. Fails with
    /// `ZedUpdateRequiredError` when the server reports a minimum required
    /// version newer than `app_version`. On success, also returns any usage
    /// information parsed from the response headers.
    async fn send_api_request<Res>(
        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
    ) -> Result<(Res, Option<EditPredictionUsage>)>
    where
        Res: DeserializeOwned,
    {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);

            let request = build(
                request_builder
                    .header("Content-Type", "application/json")
                    .header("Authorization", format!("Bearer {}", token))
                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
            )?;

            let mut response = http_client.send(request).await?;

            // Enforce the server-announced minimum client version before
            // inspecting the rest of the response.
            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                // Expired token: refresh it and retry the request once.
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
1506
    /// Gap between context refreshes beyond which the user is considered to
    /// have been idle; the first edit after idling refreshes context immediately.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Debounce applied before refreshing context while the user is actively editing.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1509
    // Refresh the related excerpts when the user just begins editing after
    // an idle period, and after they pause editing.
    fn refresh_context_if_needed(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) {
        // Context retrieval only applies to the agentic context mode.
        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
            return;
        }

        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        // The user was "idle" when there has been no refresh trigger for longer
        // than the idle duration (or none has happened yet).
        let now = Instant::now();
        let was_idle = zeta_project
            .refresh_context_timestamp
            .map_or(true, |timestamp| {
                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
            });
        zeta_project.refresh_context_timestamp = Some(now);
        // NOTE(review): assigning over the previous debounce task presumably
        // cancels it so only the latest edit schedules a refresh — confirm
        // gpui::Task drop semantics.
        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
            let buffer = buffer.clone();
            let project = project.clone();
            async move |this, cx| {
                if was_idle {
                    // First edit after idling: refresh without waiting.
                    log::debug!("refetching edit prediction context after idle");
                } else {
                    // Actively editing: wait for a pause before refreshing.
                    cx.background_executor()
                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
                        .await;
                    log::debug!("refetching edit prediction context after pause");
                }
                this.update(cx, |this, cx| {
                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);

                    // Hold onto the refresh task so it runs to completion.
                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
                        zeta_project.refresh_context_task = Some(task.log_err());
                    };
                })
                .ok()
            }
        }));
    }
1557
    /// Refresh the related excerpts asynchronously. Ensure the task runs to completion,
    /// and avoid spawning more than one concurrent task.
    ///
    /// Flow: build a "plan context retrieval" prompt around the cursor excerpt,
    /// ask the model to call the search tool, run the returned workspace
    /// searches, then store the resulting excerpts as the project's context.
    /// Returns a ready no-op task unless the project is registered and the
    /// context mode is `Agentic`.
    pub fn refresh_context(
        &mut self,
        project: Entity<Project>,
        buffer: Entity<language::Buffer>,
        cursor_position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Only registered projects carry retrieval state.
        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
            return Task::ready(anyhow::Ok(()));
        };

        // Agentic (LLM-planned) context mode is required; syntax mode gathers
        // context elsewhere.
        let ContextMode::Agentic(options) = &self.options().context else {
            return Task::ready(anyhow::Ok(()));
        };

        let snapshot = buffer.read(cx).snapshot();
        let cursor_point = cursor_position.to_point(&snapshot);
        // Select the excerpt surrounding the cursor that anchors the prompt;
        // bail quietly if no excerpt can be selected.
        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &snapshot,
            &options.excerpt,
            None,
        ) else {
            return Task::ready(Ok(()));
        };

        // Clone everything the async task needs before moving into the spawn.
        let app_version = AppVersion::global(cx);
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let debug_tx = self.debug_tx.clone();
        let current_file_path: Arc<Path> = snapshot
            .file()
            .map(|f| f.full_path(cx).into())
            .unwrap_or_else(|| Path::new("untitled").into());

        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
            predict_edits_v3::PlanContextRetrievalRequest {
                excerpt: cursor_excerpt.text(&snapshot).body,
                excerpt_path: current_file_path,
                excerpt_line_range: cursor_excerpt.line_range,
                cursor_file_max_row: Line(snapshot.max_point().row),
                // Recent edit events are included so the model can plan
                // searches relevant to what the user is doing.
                events: zeta_project
                    .events
                    .iter()
                    .filter_map(|ev| ev.to_request_event(cx))
                    .collect(),
            },
        ) {
            Ok(prompt) => prompt,
            Err(err) => {
                return Task::ready(Err(err));
            }
        };

        // Notify any attached debug view that retrieval has started.
        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                    ZetaContextRetrievalStartedDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        search_prompt: prompt.clone(),
                    },
                ))
                .ok();
        }

        // Function-scoped static: the tool schema is computed once and reused
        // across calls.
        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            );

            // NOTE(review): assumes the generated root schema always carries a
            // `description` string (presumably from `SearchToolInput` docs) —
            // confirm; `unwrap` panics otherwise.
            let description = schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string();

            (schema.into(), description)
        });

        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();

        // Non-streaming completion request exposing a single search tool the
        // model is expected to call.
        let request = open_ai::Request {
            model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
            messages: vec![open_ai::RequestMessage::User {
                content: open_ai::MessageContent::Plain(prompt),
            }],
            stream: false,
            max_completion_tokens: None,
            stop: Default::default(),
            temperature: 0.7,
            tool_choice: None,
            parallel_tool_calls: None,
            tools: vec![open_ai::ToolDefinition::Function {
                function: FunctionDefinition {
                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
                    description: Some(tool_description),
                    parameters: Some(tool_schema),
                },
            }],
            prompt_cache_key: None,
            reasoning_effort: None,
        };

        #[cfg(feature = "eval-support")]
        let eval_cache = self.eval_cache.clone();

        cx.spawn(async move |this, cx| {
            log::trace!("Sending search planning request");
            let response = Self::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache.clone(),
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Context,
            )
            .await;
            let mut response = Self::handle_api_response(&this, response, cx)?;
            log::trace!("Got search planning response");

            let choice = response
                .choices
                .pop()
                .context("No choices in retrieval response")?;
            let open_ai::RequestMessage::Assistant {
                content: _,
                tool_calls,
            } = choice.message
            else {
                anyhow::bail!("Retrieval response didn't include an assistant message");
            };

            // Collect queries from every call to the expected search tool;
            // calls to unknown tools are logged and skipped rather than
            // failing the whole refresh.
            let mut queries: Vec<SearchToolQuery> = Vec::new();
            for tool_call in tool_calls {
                let open_ai::ToolCallContent::Function { function } = tool_call.content;
                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
                    log::warn!(
                        "Context retrieval response tried to call an unknown tool: {}",
                        function.name
                    );

                    continue;
                }

                let input: SearchToolInput = serde_json::from_str(&function.arguments)
                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
                queries.extend(input.queries);
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                        ZetaSearchQueryDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                            search_queries: queries.clone(),
                        },
                    ))
                    .ok();
            }

            log::trace!("Running retrieval search: {queries:#?}");

            let related_excerpts_result = retrieval_search::run_retrieval_searches(
                queries,
                project.clone(),
                #[cfg(feature = "eval-support")]
                eval_cache,
                cx,
            )
            .await;

            log::trace!("Search queries executed");

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            this.update(cx, |this, _cx| {
                // The project may have been dropped while searches ran.
                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
                    return Ok(());
                };
                // Clear the in-flight task handle now that we've completed.
                zeta_project.refresh_context_task.take();
                if let Some(debug_tx) = &this.debug_tx {
                    debug_tx
                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
                            ZetaContextRetrievalDebugInfo {
                                project,
                                timestamp: Instant::now(),
                            },
                        ))
                        .ok();
                }
                match related_excerpts_result {
                    Ok(excerpts) => {
                        zeta_project.context = Some(excerpts);
                        Ok(())
                    }
                    Err(error) => Err(error),
                }
            })?
        })
    }
1773
1774    pub fn set_context(
1775        &mut self,
1776        project: Entity<Project>,
1777        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1778    ) {
1779        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1780            zeta_project.context = Some(context);
1781        }
1782    }
1783
1784    fn gather_nearby_diagnostics(
1785        cursor_offset: usize,
1786        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1787        snapshot: &BufferSnapshot,
1788        max_diagnostics_bytes: usize,
1789    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1790        // TODO: Could make this more efficient
1791        let mut diagnostic_groups = Vec::new();
1792        for (language_server_id, diagnostics) in diagnostic_sets {
1793            let mut groups = Vec::new();
1794            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1795            diagnostic_groups.extend(
1796                groups
1797                    .into_iter()
1798                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1799            );
1800        }
1801
1802        // sort by proximity to cursor
1803        diagnostic_groups.sort_by_key(|group| {
1804            let range = &group.entries[group.primary_ix].range;
1805            if range.start >= cursor_offset {
1806                range.start - cursor_offset
1807            } else if cursor_offset >= range.end {
1808                cursor_offset - range.end
1809            } else {
1810                (cursor_offset - range.start).min(range.end - cursor_offset)
1811            }
1812        });
1813
1814        let mut results = Vec::new();
1815        let mut diagnostic_groups_truncated = false;
1816        let mut diagnostics_byte_count = 0;
1817        for group in diagnostic_groups {
1818            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1819            diagnostics_byte_count += raw_value.get().len();
1820            if diagnostics_byte_count > max_diagnostics_bytes {
1821                diagnostic_groups_truncated = true;
1822                break;
1823            }
1824            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1825        }
1826
1827        (results, diagnostic_groups_truncated)
1828    }
1829
    // TODO: Dedupe with similar code in request_prediction?
    /// Builds a `PredictEditsRequest` for the zeta CLI using syntax-based
    /// context gathering on a background thread.
    ///
    /// Fails when the buffer has no file path. Panics (deliberate TODO) when
    /// the configured context mode is `Agentic`, which the CLI does not
    /// support yet.
    pub fn cloud_request_for_zeta_cli(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
        let project_state = self.projects.get(&project.entity_id());

        // Clone the syntax index state handle (if any) so the background task
        // can lock it without touching the entity system.
        let index_state = project_state.and_then(|state| {
            state
                .syntax_index
                .as_ref()
                .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
        });
        let options = self.options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        // Snapshot all worktrees up front; needed to resolve declaration paths
        // off the main thread.
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        // Absolute path of the buffer's parent directory, when the buffer
        // belongs to a worktree and is not at the worktree root.
        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
            let mut path = f.worktree.read(cx).absolutize(&f.path);
            if path.pop() { Some(path) } else { None }
        });

        cx.background_spawn(async move {
            let index_state = if let Some(index_state) = index_state {
                Some(index_state.lock_owned().await)
            } else {
                None
            };

            let cursor_point = position.to_point(&snapshot);

            // The CLI always requests debug info in the response.
            let debug_info = true;
            EditPredictionContext::gather_context(
                cursor_point,
                &snapshot,
                parent_abs_path.as_deref(),
                match &options.context {
                    ContextMode::Agentic(_) => {
                        // TODO
                        panic!("Llm mode not supported in zeta cli yet");
                    }
                    ContextMode::Syntax(edit_prediction_context_options) => {
                        edit_prediction_context_options
                    }
                },
                index_state.as_deref(),
            )
            .context("Failed to select excerpt")
            .map(|context| {
                make_syntax_context_cloud_request(
                    excerpt_path.into(),
                    context,
                    // TODO pass everything
                    Vec::new(),
                    false,
                    Vec::new(),
                    false,
                    None,
                    debug_info,
                    &worktree_snapshots,
                    index_state.as_deref(),
                    Some(options.max_prompt_bytes),
                    options.prompt_format,
                )
            })
        })
    }
1907
1908    pub fn wait_for_initial_indexing(
1909        &mut self,
1910        project: &Entity<Project>,
1911        cx: &mut Context<Self>,
1912    ) -> Task<Result<()>> {
1913        let zeta_project = self.get_or_init_zeta_project(project, cx);
1914        if let Some(syntax_index) = &zeta_project.syntax_index {
1915            syntax_index.read(cx).wait_for_initial_file_indexing(cx)
1916        } else {
1917            Task::ready(Ok(()))
1918        }
1919    }
1920}
1921
1922pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1923    let choice = res.choices.pop()?;
1924    let output_text = match choice.message {
1925        open_ai::RequestMessage::Assistant {
1926            content: Some(open_ai::MessageContent::Plain(content)),
1927            ..
1928        } => content,
1929        open_ai::RequestMessage::Assistant {
1930            content: Some(open_ai::MessageContent::Multipart(mut content)),
1931            ..
1932        } => {
1933            if content.is_empty() {
1934                log::error!("No output from Baseten completion response");
1935                return None;
1936            }
1937
1938            match content.remove(0) {
1939                open_ai::MessagePart::Text { text } => text,
1940                open_ai::MessagePart::Image { .. } => {
1941                    log::error!("Expected text, got an image");
1942                    return None;
1943                }
1944            }
1945        }
1946        _ => {
1947            log::error!("Invalid response message: {:?}", choice.message);
1948            return None;
1949        }
1950    };
1951    Some(output_text)
1952}
1953
/// Error surfaced when the prediction backend requires a newer Zed version;
/// `minimum_version` is the lowest version that may continue using edit
/// predictions. The `#[error]` string is shown to the user.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    minimum_version: SemanticVersion,
}
1961
1962fn make_syntax_context_cloud_request(
1963    excerpt_path: Arc<Path>,
1964    context: EditPredictionContext,
1965    events: Vec<predict_edits_v3::Event>,
1966    can_collect_data: bool,
1967    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1968    diagnostic_groups_truncated: bool,
1969    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1970    debug_info: bool,
1971    worktrees: &Vec<worktree::Snapshot>,
1972    index_state: Option<&SyntaxIndexState>,
1973    prompt_max_bytes: Option<usize>,
1974    prompt_format: PromptFormat,
1975) -> predict_edits_v3::PredictEditsRequest {
1976    let mut signatures = Vec::new();
1977    let mut declaration_to_signature_index = HashMap::default();
1978    let mut referenced_declarations = Vec::new();
1979
1980    for snippet in context.declarations {
1981        let project_entry_id = snippet.declaration.project_entry_id();
1982        let Some(path) = worktrees.iter().find_map(|worktree| {
1983            worktree.entry_for_id(project_entry_id).map(|entry| {
1984                let mut full_path = RelPathBuf::new();
1985                full_path.push(worktree.root_name());
1986                full_path.push(&entry.path);
1987                full_path
1988            })
1989        }) else {
1990            continue;
1991        };
1992
1993        let parent_index = index_state.and_then(|index_state| {
1994            snippet.declaration.parent().and_then(|parent| {
1995                add_signature(
1996                    parent,
1997                    &mut declaration_to_signature_index,
1998                    &mut signatures,
1999                    index_state,
2000                )
2001            })
2002        });
2003
2004        let (text, text_is_truncated) = snippet.declaration.item_text();
2005        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
2006            path: path.as_std_path().into(),
2007            text: text.into(),
2008            range: snippet.declaration.item_line_range(),
2009            text_is_truncated,
2010            signature_range: snippet.declaration.signature_range_in_item_text(),
2011            parent_index,
2012            signature_score: snippet.score(DeclarationStyle::Signature),
2013            declaration_score: snippet.score(DeclarationStyle::Declaration),
2014            score_components: snippet.components,
2015        });
2016    }
2017
2018    let excerpt_parent = index_state.and_then(|index_state| {
2019        context
2020            .excerpt
2021            .parent_declarations
2022            .last()
2023            .and_then(|(parent, _)| {
2024                add_signature(
2025                    *parent,
2026                    &mut declaration_to_signature_index,
2027                    &mut signatures,
2028                    index_state,
2029                )
2030            })
2031    });
2032
2033    predict_edits_v3::PredictEditsRequest {
2034        excerpt_path,
2035        excerpt: context.excerpt_text.body,
2036        excerpt_line_range: context.excerpt.line_range,
2037        excerpt_range: context.excerpt.range,
2038        cursor_point: predict_edits_v3::Point {
2039            line: predict_edits_v3::Line(context.cursor_point.row),
2040            column: context.cursor_point.column,
2041        },
2042        referenced_declarations,
2043        included_files: vec![],
2044        signatures,
2045        excerpt_parent,
2046        events,
2047        can_collect_data,
2048        diagnostic_groups,
2049        diagnostic_groups_truncated,
2050        git_info,
2051        debug_info,
2052        prompt_max_bytes,
2053        prompt_format,
2054    }
2055}
2056
2057fn add_signature(
2058    declaration_id: DeclarationId,
2059    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
2060    signatures: &mut Vec<Signature>,
2061    index: &SyntaxIndexState,
2062) -> Option<usize> {
2063    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
2064        return Some(*signature_index);
2065    }
2066    let Some(parent_declaration) = index.declaration(declaration_id) else {
2067        log::error!("bug: missing parent declaration");
2068        return None;
2069    };
2070    let parent_index = parent_declaration.parent().and_then(|parent| {
2071        add_signature(parent, declaration_to_signature_index, signatures, index)
2072    });
2073    let (text, text_is_truncated) = parent_declaration.signature_text();
2074    let signature_index = signatures.len();
2075    signatures.push(Signature {
2076        text: text.into(),
2077        text_is_truncated,
2078        parent_index,
2079        range: parent_declaration.signature_line_range(),
2080    });
2081    declaration_to_signature_index.insert(declaration_id, signature_index);
2082    Some(signature_index)
2083}
2084
/// Cache key for eval runs: the kind of cached interaction paired with a hash
/// of the request input.
#[cfg(feature = "eval-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);

/// Which stage of the prediction pipeline an eval cache entry belongs to.
/// `Context` is used for the search-planning LLM request; `Search` and
/// `Prediction` presumably cover search execution and the prediction request
/// itself — confirm against their call sites.
#[cfg(feature = "eval-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
    Context,
    Search,
    Prediction,
}
2095
#[cfg(feature = "eval-support")]
impl std::fmt::Display for EvalCacheEntryKind {
    /// Formats the entry kind as its stable lowercase name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            EvalCacheEntryKind::Context => "context",
            EvalCacheEntryKind::Search => "search",
            EvalCacheEntryKind::Prediction => "prediction",
        };
        write!(f, "{name}")
    }
}
2106
/// Cache consulted during eval runs. `read` returns a previously recorded
/// value for `key`; `write` records an input/value pair.
/// NOTE(review): the roles of `input` vs `value` are inferred from the
/// signature alone — confirm against implementors.
#[cfg(feature = "eval-support")]
pub trait EvalCache: Send + Sync {
    fn read(&self, key: EvalCacheKey) -> Option<String>;
    fn write(&self, key: EvalCacheKey, input: &str, value: &str);
}
2112
2113#[cfg(test)]
2114mod tests {
2115    use std::{path::Path, sync::Arc};
2116
2117    use client::UserStore;
2118    use clock::FakeSystemClock;
2119    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
2120    use futures::{
2121        AsyncReadExt, StreamExt,
2122        channel::{mpsc, oneshot},
2123    };
2124    use gpui::{
2125        Entity, TestAppContext,
2126        http_client::{FakeHttpClient, Response},
2127        prelude::*,
2128    };
2129    use indoc::indoc;
2130    use language::OffsetRangeExt as _;
2131    use open_ai::Usage;
2132    use pretty_assertions::{assert_eq, assert_matches};
2133    use project::{FakeFs, Project};
2134    use serde_json::json;
2135    use settings::SettingsStore;
2136    use util::path;
2137    use uuid::Uuid;
2138
2139    use crate::{BufferEditPrediction, Zeta};
2140
    /// End-to-end check of prediction state transitions: a local prediction in
    /// the current buffer, a context refresh driven by a fake tool-call
    /// response, then a prediction targeting a different file, which should
    /// surface as `Jump` from buffer 1 and `Local` from buffer 2.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye\n",
                "2.txt": "Hola!\nComo\nAdios\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();

        // Respond with a diff that edits 1.txt; the prediction should be
        // Local since it targets the buffer the cursor is in.
        respond_tx
            .send(model_response(indoc! {r"
                --- a/root/1.txt
                +++ b/root/1.txt
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Context refresh
        let refresh_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // Hand back an assistant message whose tool call searches 2.txt,
        // pulling that file into the retrieval context.
        respond_tx
            .send(open_ai::Response {
                id: Uuid::new_v4().to_string(),
                object: "response".into(),
                created: 0,
                model: "model".into(),
                choices: vec![open_ai::Choice {
                    index: 0,
                    message: open_ai::RequestMessage::Assistant {
                        content: None,
                        tool_calls: vec![open_ai::ToolCall {
                            id: "search".into(),
                            content: open_ai::ToolCallContent::Function {
                                function: open_ai::FunctionContent {
                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
                                        .to_string(),
                                    arguments: serde_json::to_string(&SearchToolInput {
                                        queries: Box::new([SearchToolQuery {
                                            glob: "root/2.txt".to_string(),
                                            syntax_node: vec![],
                                            content: Some(".".into()),
                                        }]),
                                    })
                                    .unwrap(),
                                },
                            },
                        }],
                    },
                    finish_reason: None,
                }],
                usage: Usage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                },
            })
            .unwrap();
        refresh_task.await.unwrap();

        zeta.update(cx, |zeta, _cx| {
            zeta.discard_current_prediction(&project);
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/2.txt
                +++ b/root/2.txt
                 Hola!
                -Como
                +Como estas?
                 Adios
            "#}))
            .unwrap();
        prediction_task.await.unwrap();
        // From buffer 1 the prediction is a Jump to 2.txt...
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // ...but from buffer 2 itself, the same prediction is Local.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
2285
    /// A single prediction round trip: request a prediction, answer with a
    /// unified diff, and check the resulting edit's position and text.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        // Cursor at line 1, column 3 (after "How").
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (_, respond_tx) = req_rx.next().await.unwrap();

        // TODO Put back when we have a structured request again
        // assert_eq!(
        //     request.excerpt_path.as_ref(),
        //     Path::new(path!("root/foo.md"))
        // );
        // assert_eq!(
        //     request.cursor_point,
        //     Point {
        //         line: Line(1),
        //         column: 3
        //     }
        // );

        respond_tx
            .send(model_response(indoc! { r"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // The diff's "How" -> "How are you?" should yield a single insertion
        // of " are you?" at the cursor position.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2349
    /// Verifies that edits made to a registered buffer are reported as diff
    /// events in the prediction request prompt, and that the prediction round
    /// trip still produces the expected edit.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye\n"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Register the buffer so Zeta tracks its edit history.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Edit after registration: inserts "How" on the blank line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        // The tracked edit must appear in the prompt as a unified diff event.
        let prompt = prompt_from_request(&request);
        assert!(
            prompt.contains(indoc! {"
            --- a/root/foo.md
            +++ b/root/foo.md
            @@ -1,3 +1,3 @@
             Hello!
            -
            +How
             Bye
        "}),
            "{prompt}"
        );

        respond_tx
            .send(model_response(indoc! {r#"
                --- a/root/foo.md
                +++ b/root/foo.md
                @@ ... @@
                 Hello!
                -How
                +How are you?
                 Bye
            "#}))
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
    }
2423
2424    // Skipped until we start including diagnostics in prompt
2425    // #[gpui::test]
2426    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2427    //     let (zeta, mut req_rx) = init_test(cx);
2428    //     let fs = FakeFs::new(cx.executor());
2429    //     fs.insert_tree(
2430    //         "/root",
2431    //         json!({
2432    //             "foo.md": "Hello!\nBye"
2433    //         }),
2434    //     )
2435    //     .await;
2436    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2437
2438    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2439    //     let diagnostic = lsp::Diagnostic {
2440    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2441    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
2442    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2443    //         ..Default::default()
2444    //     };
2445
2446    //     project.update(cx, |project, cx| {
2447    //         project.lsp_store().update(cx, |lsp_store, cx| {
2448    //             // Create some diagnostics
2449    //             lsp_store
2450    //                 .update_diagnostics(
2451    //                     LanguageServerId(0),
2452    //                     lsp::PublishDiagnosticsParams {
2453    //                         uri: path_to_buffer_uri.clone(),
2454    //                         diagnostics: vec![diagnostic],
2455    //                         version: None,
2456    //                     },
2457    //                     None,
2458    //                     language::DiagnosticSourceKind::Pushed,
2459    //                     &[],
2460    //                     cx,
2461    //                 )
2462    //                 .unwrap();
2463    //         });
2464    //     });
2465
2466    //     let buffer = project
2467    //         .update(cx, |project, cx| {
2468    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2469    //             project.open_buffer(path, cx)
2470    //         })
2471    //         .await
2472    //         .unwrap();
2473
2474    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2475    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
2476
2477    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
2478    //         zeta.request_prediction(&project, &buffer, position, cx)
2479    //     });
2480
2481    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
2482
2483    //     assert_eq!(request.diagnostic_groups.len(), 1);
2484    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2485    //         .unwrap();
2486    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2487    //     assert_eq!(
2488    //         value,
2489    //         json!({
2490    //             "entries": [{
2491    //                 "range": {
2492    //                     "start": 8,
2493    //                     "end": 10
2494    //                 },
2495    //                 "diagnostic": {
2496    //                     "source": null,
2497    //                     "code": null,
2498    //                     "code_description": null,
2499    //                     "severity": 1,
2500    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2501    //                     "markdown": null,
2502    //                     "group_id": 0,
2503    //                     "is_primary": true,
2504    //                     "is_disk_based": false,
2505    //                     "is_unnecessary": false,
2506    //                     "source_kind": "Pushed",
2507    //                     "data": null,
2508    //                     "underline": true
2509    //                 }
2510    //             }],
2511    //             "primary_ix": 0
2512    //         })
2513    //     );
2514    // }
2515
2516    fn model_response(text: &str) -> open_ai::Response {
2517        open_ai::Response {
2518            id: Uuid::new_v4().to_string(),
2519            object: "response".into(),
2520            created: 0,
2521            model: "model".into(),
2522            choices: vec![open_ai::Choice {
2523                index: 0,
2524                message: open_ai::RequestMessage::Assistant {
2525                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
2526                    tool_calls: vec![],
2527                },
2528                finish_reason: None,
2529            }],
2530            usage: Usage {
2531                prompt_tokens: 0,
2532                completion_tokens: 0,
2533                total_tokens: 0,
2534            },
2535        }
2536    }
2537
2538    fn prompt_from_request(request: &open_ai::Request) -> &str {
2539        assert_eq!(request.messages.len(), 1);
2540        let open_ai::RequestMessage::User {
2541            content: open_ai::MessageContent::Plain(content),
2542            ..
2543        } = &request.messages[0]
2544        else {
2545            panic!(
2546                "Request does not have single user message of type Plain. {:#?}",
2547                request
2548            );
2549        };
2550        content
2551    }
2552
    /// Shared test setup: creates a `Zeta` entity backed by a fake HTTP
    /// client, returning it together with a channel receiver that yields
    /// each intercepted `/predict_edits/raw` request plus a oneshot sender
    /// the test uses to supply the fake model response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();

            // Channel through which intercepted prediction requests are
            // forwarded to the test body.
            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            // Token endpoint: always hand back a static
                            // test token.
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            // Prediction endpoint: decode the request body,
                            // forward it to the test, and block until the
                            // test sends a response through the oneshot.
                            "/predict_edits/raw" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            // Any other path is a test bug.
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
2607}