zeta2.rs

   1use anyhow::{Context as _, Result, anyhow, bail};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
  10use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
  11use collections::HashMap;
  12use edit_prediction_context::{
  13    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  14    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  15    SyntaxIndex, SyntaxIndexState,
  16};
  17use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  18use futures::AsyncReadExt as _;
  19use futures::channel::{mpsc, oneshot};
  20use gpui::http_client::{AsyncBody, Method};
  21use gpui::{
  22    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  23    http_client, prelude::*,
  24};
  25use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  26use language::{BufferSnapshot, OffsetRangeExt};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use open_ai::FunctionDefinition;
  29use project::Project;
  30use release_channel::AppVersion;
  31use serde::de::DeserializeOwned;
  32use std::collections::{VecDeque, hash_map};
  33
  34use std::env;
  35use std::ops::Range;
  36use std::path::Path;
  37use std::str::FromStr as _;
  38use std::sync::{Arc, LazyLock};
  39use std::time::{Duration, Instant};
  40use thiserror::Error;
  41use util::rel_path::RelPathBuf;
  42use util::{LogErrorFuture, TryFutureExt};
  43use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  44
  45pub mod merge_excerpts;
  46mod prediction;
  47mod provider;
  48pub mod retrieval_search;
  49pub mod udiff;
  50mod xml_edits;
  51
  52use crate::merge_excerpts::merge_excerpts;
  53use crate::prediction::EditPrediction;
  54pub use crate::prediction::EditPredictionId;
  55pub use provider::ZetaEditPredictionProvider;
  56
  /// Cap on the number of events kept in a project's history. Once the history
  /// reaches this size, `Zeta::push_event` discards the oldest half before appending.
  const MAX_EVENT_COUNT: usize = 16;
  59
  60pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  61    max_bytes: 512,
  62    min_bytes: 128,
  63    target_before_cursor_over_total_bytes: 0.5,
  64};
  65
  66pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
  67    ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);
  68
  69pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
  70    excerpt: DEFAULT_EXCERPT_OPTIONS,
  71};
  72
  73pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
  74    EditPredictionContextOptions {
  75        use_imports: true,
  76        max_retrieved_declarations: 0,
  77        excerpt: DEFAULT_EXCERPT_OPTIONS,
  78        score: EditPredictionScoreOptions {
  79            omit_excerpt_overlaps: true,
  80        },
  81    };
  82
  83pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  84    context: DEFAULT_CONTEXT_OPTIONS,
  85    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  86    max_diagnostic_bytes: 2048,
  87    prompt_format: PromptFormat::DEFAULT,
  88    file_indexing_parallelism: 1,
  89    buffer_change_grouping_interval: Duration::from_secs(1),
  90};
  91
  /// True when the `ZED_ZETA2_OLLAMA` environment variable is set to a
  /// non-empty (valid Unicode) value. Evaluated once, on first access.
  static USE_OLLAMA: LazyLock<bool> =
      LazyLock::new(|| matches!(env::var("ZED_ZETA2_OLLAMA"), Ok(value) if !value.is_empty()));
  94static MODEL_ID: LazyLock<String> = LazyLock::new(|| {
  95    env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA {
  96        "qwen3-coder:30b".to_string()
  97    } else {
  98        "yqvev8r3".to_string()
  99    })
 100});
 101static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
 102    env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
 103        if *USE_OLLAMA {
 104            Some("http://localhost:11434/v1/chat/completions".into())
 105        } else {
 106            None
 107        }
 108    })
 109});
 110
 111pub struct Zeta2FeatureFlag;
 112
 113impl FeatureFlag for Zeta2FeatureFlag {
 114    const NAME: &'static str = "zeta2";
 115
 116    fn enabled_for_staff() -> bool {
 117        false
 118    }
 119}
 120
 121#[derive(Clone)]
 122struct ZetaGlobal(Entity<Zeta>);
 123
 124impl Global for ZetaGlobal {}
 125
 126pub struct Zeta {
 127    client: Arc<Client>,
 128    user_store: Entity<UserStore>,
 129    llm_token: LlmApiToken,
 130    _llm_token_subscription: Subscription,
 131    projects: HashMap<EntityId, ZetaProject>,
 132    options: ZetaOptions,
 133    update_required: bool,
 134    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
 135    #[cfg(feature = "llm-response-cache")]
 136    llm_response_cache: Option<Arc<dyn LlmResponseCache>>,
 137}
 138
 /// Pluggable cache for LLM responses, keyed by a `u64` derived from the
 /// request URL and body. Only compiled with the `llm-response-cache` feature.
 #[cfg(feature = "llm-response-cache")]
 pub trait LlmResponseCache: Send + Sync {
     /// Computes the cache key for a request to `url` with the given `body`.
     fn get_key(&self, url: &gpui::http_client::Url, body: &str) -> u64;
     /// Returns the stored response for `key`, if any.
     fn read_response(&self, key: u64) -> Option<String>;
     /// Stores `value` as the response for `key`.
     fn write_response(&self, key: u64, value: &str);
 }
 145
 146#[derive(Debug, Clone, PartialEq)]
 147pub struct ZetaOptions {
 148    pub context: ContextMode,
 149    pub max_prompt_bytes: usize,
 150    pub max_diagnostic_bytes: usize,
 151    pub prompt_format: predict_edits_v3::PromptFormat,
 152    pub file_indexing_parallelism: usize,
 153    pub buffer_change_grouping_interval: Duration,
 154}
 155
 156#[derive(Debug, Clone, PartialEq)]
 157pub enum ContextMode {
 158    Agentic(AgenticContextOptions),
 159    Syntax(EditPredictionContextOptions),
 160}
 161
 162#[derive(Debug, Clone, PartialEq)]
 163pub struct AgenticContextOptions {
 164    pub excerpt: EditPredictionExcerptOptions,
 165}
 166
 167impl ContextMode {
 168    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 169        match self {
 170            ContextMode::Agentic(options) => &options.excerpt,
 171            ContextMode::Syntax(options) => &options.excerpt,
 172        }
 173    }
 174}
 175
 176#[derive(Debug)]
 177pub enum ZetaDebugInfo {
 178    ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
 179    SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
 180    SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
 181    ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
 182    EditPredictionRequested(ZetaEditPredictionDebugInfo),
 183}
 184
 185#[derive(Debug)]
 186pub struct ZetaContextRetrievalStartedDebugInfo {
 187    pub project: Entity<Project>,
 188    pub timestamp: Instant,
 189    pub search_prompt: String,
 190}
 191
 192#[derive(Debug)]
 193pub struct ZetaContextRetrievalDebugInfo {
 194    pub project: Entity<Project>,
 195    pub timestamp: Instant,
 196}
 197
 198#[derive(Debug)]
 199pub struct ZetaEditPredictionDebugInfo {
 200    pub request: predict_edits_v3::PredictEditsRequest,
 201    pub retrieval_time: TimeDelta,
 202    pub buffer: WeakEntity<Buffer>,
 203    pub position: language::Anchor,
 204    pub local_prompt: Result<String, String>,
 205    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
 206}
 207
 208#[derive(Debug)]
 209pub struct ZetaSearchQueryDebugInfo {
 210    pub project: Entity<Project>,
 211    pub timestamp: Instant,
 212    pub search_queries: Vec<SearchToolQuery>,
 213}
 214
 215pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 216
 217struct ZetaProject {
 218    syntax_index: Entity<SyntaxIndex>,
 219    events: VecDeque<Event>,
 220    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 221    current_prediction: Option<CurrentEditPrediction>,
 222    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
 223    refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
 224    refresh_context_debounce_task: Option<Task<Option<()>>>,
 225    refresh_context_timestamp: Option<Instant>,
 226}
 227
 228#[derive(Debug, Clone)]
 229struct CurrentEditPrediction {
 230    pub requested_by_buffer_id: EntityId,
 231    pub prediction: EditPrediction,
 232}
 233
 234impl CurrentEditPrediction {
 235    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 236        let Some(new_edits) = self
 237            .prediction
 238            .interpolate(&self.prediction.buffer.read(cx))
 239        else {
 240            return false;
 241        };
 242
 243        if self.prediction.buffer != old_prediction.prediction.buffer {
 244            return true;
 245        }
 246
 247        let Some(old_edits) = old_prediction
 248            .prediction
 249            .interpolate(&old_prediction.prediction.buffer.read(cx))
 250        else {
 251            return true;
 252        };
 253
 254        // This reduces the occurrence of UI thrash from replacing edits
 255        //
 256        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 257        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 258            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 259            && old_edits.len() == 1
 260            && new_edits.len() == 1
 261        {
 262            let (old_range, old_text) = &old_edits[0];
 263            let (new_range, new_text) = &new_edits[0];
 264            new_range == old_range && new_text.starts_with(old_text.as_ref())
 265        } else {
 266            true
 267        }
 268    }
 269}
 270
 271/// A prediction from the perspective of a buffer.
 272#[derive(Debug)]
 273enum BufferEditPrediction<'a> {
 274    Local { prediction: &'a EditPrediction },
 275    Jump { prediction: &'a EditPrediction },
 276}
 277
 278struct RegisteredBuffer {
 279    snapshot: BufferSnapshot,
 280    _subscriptions: [gpui::Subscription; 2],
 281}
 282
 283#[derive(Clone)]
 284pub enum Event {
 285    BufferChange {
 286        old_snapshot: BufferSnapshot,
 287        new_snapshot: BufferSnapshot,
 288        timestamp: Instant,
 289    },
 290}
 291
 292impl Event {
 293    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 294        match self {
 295            Event::BufferChange {
 296                old_snapshot,
 297                new_snapshot,
 298                ..
 299            } => {
 300                let path = new_snapshot.file().map(|f| f.full_path(cx));
 301
 302                let old_path = old_snapshot.file().and_then(|f| {
 303                    let old_path = f.full_path(cx);
 304                    if Some(&old_path) != path.as_ref() {
 305                        Some(old_path)
 306                    } else {
 307                        None
 308                    }
 309                });
 310
 311                // TODO [zeta2] move to bg?
 312                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 313
 314                if path == old_path && diff.is_empty() {
 315                    None
 316                } else {
 317                    Some(predict_edits_v3::Event::BufferChange {
 318                        old_path,
 319                        path,
 320                        diff,
 321                        //todo: Actually detect if this edit was predicted or not
 322                        predicted: false,
 323                    })
 324                }
 325            }
 326        }
 327    }
 328}
 329
 330impl Zeta {
 331    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 332        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 333    }
 334
 335    pub fn global(
 336        client: &Arc<Client>,
 337        user_store: &Entity<UserStore>,
 338        cx: &mut App,
 339    ) -> Entity<Self> {
 340        cx.try_global::<ZetaGlobal>()
 341            .map(|global| global.0.clone())
 342            .unwrap_or_else(|| {
 343                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 344                cx.set_global(ZetaGlobal(zeta.clone()));
 345                zeta
 346            })
 347    }
 348
 349    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 350        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 351
 352        Self {
 353            projects: HashMap::default(),
 354            client,
 355            user_store,
 356            options: DEFAULT_OPTIONS,
 357            llm_token: LlmApiToken::default(),
 358            _llm_token_subscription: cx.subscribe(
 359                &refresh_llm_token_listener,
 360                |this, _listener, _event, cx| {
 361                    let client = this.client.clone();
 362                    let llm_token = this.llm_token.clone();
 363                    cx.spawn(async move |_this, _cx| {
 364                        llm_token.refresh(&client).await?;
 365                        anyhow::Ok(())
 366                    })
 367                    .detach_and_log_err(cx);
 368                },
 369            ),
 370            update_required: false,
 371            debug_tx: None,
 372            #[cfg(feature = "llm-response-cache")]
 373            llm_response_cache: None,
 374        }
 375    }
 376
 /// Installs `cache` as the LLM response cache, replacing any previous one.
 #[cfg(feature = "llm-response-cache")]
 pub fn with_llm_response_cache(&mut self, cache: Arc<dyn LlmResponseCache>) {
     self.llm_response_cache = Some(cache);
 }
 381
 382    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
 383        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 384        self.debug_tx = Some(debug_watch_tx);
 385        debug_watch_rx
 386    }
 387
 388    pub fn options(&self) -> &ZetaOptions {
 389        &self.options
 390    }
 391
 392    pub fn set_options(&mut self, options: ZetaOptions) {
 393        self.options = options;
 394    }
 395
 396    pub fn clear_history(&mut self) {
 397        for zeta_project in self.projects.values_mut() {
 398            zeta_project.events.clear();
 399        }
 400    }
 401
 402    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 403        self.projects
 404            .get(&project.entity_id())
 405            .map(|project| project.events.iter())
 406            .into_iter()
 407            .flatten()
 408    }
 409
 410    pub fn context_for_project(
 411        &self,
 412        project: &Entity<Project>,
 413    ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
 414        self.projects
 415            .get(&project.entity_id())
 416            .and_then(|project| {
 417                Some(
 418                    project
 419                        .context
 420                        .as_ref()?
 421                        .iter()
 422                        .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
 423                )
 424            })
 425            .into_iter()
 426            .flatten()
 427    }
 428
 429    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 430        self.user_store.read(cx).edit_prediction_usage()
 431    }
 432
 433    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 434        self.get_or_init_zeta_project(project, cx);
 435    }
 436
 437    pub fn register_buffer(
 438        &mut self,
 439        buffer: &Entity<Buffer>,
 440        project: &Entity<Project>,
 441        cx: &mut Context<Self>,
 442    ) {
 443        let zeta_project = self.get_or_init_zeta_project(project, cx);
 444        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 445    }
 446
 447    fn get_or_init_zeta_project(
 448        &mut self,
 449        project: &Entity<Project>,
 450        cx: &mut App,
 451    ) -> &mut ZetaProject {
 452        self.projects
 453            .entry(project.entity_id())
 454            .or_insert_with(|| ZetaProject {
 455                syntax_index: cx.new(|cx| {
 456                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 457                }),
 458                events: VecDeque::new(),
 459                registered_buffers: HashMap::default(),
 460                current_prediction: None,
 461                context: None,
 462                refresh_context_task: None,
 463                refresh_context_debounce_task: None,
 464                refresh_context_timestamp: None,
 465            })
 466    }
 467
 468    fn register_buffer_impl<'a>(
 469        zeta_project: &'a mut ZetaProject,
 470        buffer: &Entity<Buffer>,
 471        project: &Entity<Project>,
 472        cx: &mut Context<Self>,
 473    ) -> &'a mut RegisteredBuffer {
 474        let buffer_id = buffer.entity_id();
 475        match zeta_project.registered_buffers.entry(buffer_id) {
 476            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 477            hash_map::Entry::Vacant(entry) => {
 478                let snapshot = buffer.read(cx).snapshot();
 479                let project_entity_id = project.entity_id();
 480                entry.insert(RegisteredBuffer {
 481                    snapshot,
 482                    _subscriptions: [
 483                        cx.subscribe(buffer, {
 484                            let project = project.downgrade();
 485                            move |this, buffer, event, cx| {
 486                                if let language::BufferEvent::Edited = event
 487                                    && let Some(project) = project.upgrade()
 488                                {
 489                                    this.report_changes_for_buffer(&buffer, &project, cx);
 490                                }
 491                            }
 492                        }),
 493                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 494                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 495                            else {
 496                                return;
 497                            };
 498                            zeta_project.registered_buffers.remove(&buffer_id);
 499                        }),
 500                    ],
 501                })
 502            }
 503        }
 504    }
 505
 506    fn report_changes_for_buffer(
 507        &mut self,
 508        buffer: &Entity<Buffer>,
 509        project: &Entity<Project>,
 510        cx: &mut Context<Self>,
 511    ) -> BufferSnapshot {
 512        let buffer_change_grouping_interval = self.options.buffer_change_grouping_interval;
 513        let zeta_project = self.get_or_init_zeta_project(project, cx);
 514        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 515
 516        let new_snapshot = buffer.read(cx).snapshot();
 517        if new_snapshot.version != registered_buffer.snapshot.version {
 518            let old_snapshot =
 519                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 520            Self::push_event(
 521                zeta_project,
 522                buffer_change_grouping_interval,
 523                Event::BufferChange {
 524                    old_snapshot,
 525                    new_snapshot: new_snapshot.clone(),
 526                    timestamp: Instant::now(),
 527                },
 528            );
 529        }
 530
 531        new_snapshot
 532    }
 533
 534    fn push_event(
 535        zeta_project: &mut ZetaProject,
 536        buffer_change_grouping_interval: Duration,
 537        event: Event,
 538    ) {
 539        let events = &mut zeta_project.events;
 540
 541        if buffer_change_grouping_interval > Duration::ZERO
 542            && let Some(Event::BufferChange {
 543                new_snapshot: last_new_snapshot,
 544                timestamp: last_timestamp,
 545                ..
 546            }) = events.back_mut()
 547        {
 548            // Coalesce edits for the same buffer when they happen one after the other.
 549            let Event::BufferChange {
 550                old_snapshot,
 551                new_snapshot,
 552                timestamp,
 553            } = &event;
 554
 555            if timestamp.duration_since(*last_timestamp) <= buffer_change_grouping_interval
 556                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 557                && old_snapshot.version == last_new_snapshot.version
 558            {
 559                *last_new_snapshot = new_snapshot.clone();
 560                *last_timestamp = *timestamp;
 561                return;
 562            }
 563        }
 564
 565        if events.len() >= MAX_EVENT_COUNT {
 566            // These are halved instead of popping to improve prompt caching.
 567            events.drain(..MAX_EVENT_COUNT / 2);
 568        }
 569
 570        events.push_back(event);
 571    }
 572
 573    fn current_prediction_for_buffer(
 574        &self,
 575        buffer: &Entity<Buffer>,
 576        project: &Entity<Project>,
 577        cx: &App,
 578    ) -> Option<BufferEditPrediction<'_>> {
 579        let project_state = self.projects.get(&project.entity_id())?;
 580
 581        let CurrentEditPrediction {
 582            requested_by_buffer_id,
 583            prediction,
 584        } = project_state.current_prediction.as_ref()?;
 585
 586        if prediction.targets_buffer(buffer.read(cx)) {
 587            Some(BufferEditPrediction::Local { prediction })
 588        } else if *requested_by_buffer_id == buffer.entity_id() {
 589            Some(BufferEditPrediction::Jump { prediction })
 590        } else {
 591            None
 592        }
 593    }
 594
 595    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 596        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 597            return;
 598        };
 599
 600        let Some(prediction) = project_state.current_prediction.take() else {
 601            return;
 602        };
 603        let request_id = prediction.prediction.id.to_string();
 604
 605        let client = self.client.clone();
 606        let llm_token = self.llm_token.clone();
 607        let app_version = AppVersion::global(cx);
 608        cx.spawn(async move |this, cx| {
 609            let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
 610                http_client::Url::parse(&predict_edits_url)?
 611            } else {
 612                client
 613                    .http_client()
 614                    .build_zed_llm_url("/predict_edits/accept", &[])?
 615            };
 616
 617            let response = cx
 618                .background_spawn(Self::send_api_request::<()>(
 619                    move |builder| {
 620                        let req = builder.uri(url.as_ref()).body(
 621                            serde_json::to_string(&AcceptEditPredictionBody {
 622                                request_id: request_id.clone(),
 623                            })?
 624                            .into(),
 625                        );
 626                        Ok(req?)
 627                    },
 628                    client,
 629                    llm_token,
 630                    app_version,
 631                ))
 632                .await;
 633
 634            Self::handle_api_response(&this, response, cx)?;
 635            anyhow::Ok(())
 636        })
 637        .detach_and_log_err(cx);
 638    }
 639
 640    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 641        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 642            project_state.current_prediction.take();
 643        };
 644    }
 645
 646    pub fn refresh_prediction(
 647        &mut self,
 648        project: &Entity<Project>,
 649        buffer: &Entity<Buffer>,
 650        position: language::Anchor,
 651        cx: &mut Context<Self>,
 652    ) -> Task<Result<()>> {
 653        let request_task = self.request_prediction(project, buffer, position, cx);
 654        let buffer = buffer.clone();
 655        let project = project.clone();
 656
 657        cx.spawn(async move |this, cx| {
 658            if let Some(prediction) = request_task.await? {
 659                this.update(cx, |this, cx| {
 660                    let project_state = this
 661                        .projects
 662                        .get_mut(&project.entity_id())
 663                        .context("Project not found")?;
 664
 665                    let new_prediction = CurrentEditPrediction {
 666                        requested_by_buffer_id: buffer.entity_id(),
 667                        prediction: prediction,
 668                    };
 669
 670                    if project_state
 671                        .current_prediction
 672                        .as_ref()
 673                        .is_none_or(|old_prediction| {
 674                            new_prediction.should_replace_prediction(&old_prediction, cx)
 675                        })
 676                    {
 677                        project_state.current_prediction = Some(new_prediction);
 678                    }
 679                    anyhow::Ok(())
 680                })??;
 681            }
 682            Ok(())
 683        })
 684    }
 685
 686    pub fn request_prediction(
 687        &mut self,
 688        project: &Entity<Project>,
 689        active_buffer: &Entity<Buffer>,
 690        position: language::Anchor,
 691        cx: &mut Context<Self>,
 692    ) -> Task<Result<Option<EditPrediction>>> {
 693        let project_state = self.projects.get(&project.entity_id());
 694
 695        let index_state = project_state.map(|state| {
 696            state
 697                .syntax_index
 698                .read_with(cx, |index, _cx| index.state().clone())
 699        });
 700        let options = self.options.clone();
 701        let active_snapshot = active_buffer.read(cx).snapshot();
 702        let Some(excerpt_path) = active_snapshot
 703            .file()
 704            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 705        else {
 706            return Task::ready(Err(anyhow!("No file path for excerpt")));
 707        };
 708        let client = self.client.clone();
 709        let llm_token = self.llm_token.clone();
 710        let app_version = AppVersion::global(cx);
 711        let worktree_snapshots = project
 712            .read(cx)
 713            .worktrees(cx)
 714            .map(|worktree| worktree.read(cx).snapshot())
 715            .collect::<Vec<_>>();
 716        let debug_tx = self.debug_tx.clone();
 717
 718        let events = project_state
 719            .map(|state| {
 720                state
 721                    .events
 722                    .iter()
 723                    .filter_map(|event| event.to_request_event(cx))
 724                    .collect::<Vec<_>>()
 725            })
 726            .unwrap_or_default();
 727
 728        let diagnostics = active_snapshot.diagnostic_sets().clone();
 729
 730        let parent_abs_path =
 731            project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
 732                let mut path = f.worktree.read(cx).absolutize(&f.path);
 733                if path.pop() { Some(path) } else { None }
 734            });
 735
 736        // TODO data collection
 737        let can_collect_data = cx.is_staff();
 738
 739        let mut included_files = project_state
 740            .and_then(|project_state| project_state.context.as_ref())
 741            .unwrap_or(&HashMap::default())
 742            .iter()
 743            .filter_map(|(buffer_entity, ranges)| {
 744                let buffer = buffer_entity.read(cx);
 745                Some((
 746                    buffer_entity.clone(),
 747                    buffer.snapshot(),
 748                    buffer.file()?.full_path(cx).into(),
 749                    ranges.clone(),
 750                ))
 751            })
 752            .collect::<Vec<_>>();
 753
 754        #[cfg(feature = "llm-response-cache")]
 755        let llm_response_cache = self.llm_response_cache.clone();
 756
 757        let request_task = cx.background_spawn({
 758            let active_buffer = active_buffer.clone();
 759            async move {
 760                let index_state = if let Some(index_state) = index_state {
 761                    Some(index_state.lock_owned().await)
 762                } else {
 763                    None
 764                };
 765
 766                let cursor_offset = position.to_offset(&active_snapshot);
 767                let cursor_point = cursor_offset.to_point(&active_snapshot);
 768
 769                let before_retrieval = chrono::Utc::now();
 770
 771                let (diagnostic_groups, diagnostic_groups_truncated) =
 772                    Self::gather_nearby_diagnostics(
 773                        cursor_offset,
 774                        &diagnostics,
 775                        &active_snapshot,
 776                        options.max_diagnostic_bytes,
 777                    );
 778
 779                let cloud_request = match options.context {
 780                    ContextMode::Agentic(context_options) => {
 781                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 782                            cursor_point,
 783                            &active_snapshot,
 784                            &context_options.excerpt,
 785                            index_state.as_deref(),
 786                        ) else {
 787                            return Ok((None, None));
 788                        };
 789
 790                        let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
 791                            ..active_snapshot.anchor_before(excerpt.range.end);
 792
 793                        if let Some(buffer_ix) =
 794                            included_files.iter().position(|(_, snapshot, _, _)| {
 795                                snapshot.remote_id() == active_snapshot.remote_id()
 796                            })
 797                        {
 798                            let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
 799                            let range_ix = ranges
 800                                .binary_search_by(|probe| {
 801                                    probe
 802                                        .start
 803                                        .cmp(&excerpt_anchor_range.start, buffer)
 804                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 805                                })
 806                                .unwrap_or_else(|ix| ix);
 807
 808                            ranges.insert(range_ix, excerpt_anchor_range);
 809                            let last_ix = included_files.len() - 1;
 810                            included_files.swap(buffer_ix, last_ix);
 811                        } else {
 812                            included_files.push((
 813                                active_buffer.clone(),
 814                                active_snapshot,
 815                                excerpt_path.clone(),
 816                                vec![excerpt_anchor_range],
 817                            ));
 818                        }
 819
 820                        let included_files = included_files
 821                            .iter()
 822                            .map(|(_, buffer, path, ranges)| {
 823                                let excerpts = merge_excerpts(
 824                                    &buffer,
 825                                    ranges.iter().map(|range| {
 826                                        let point_range = range.to_point(&buffer);
 827                                        Line(point_range.start.row)..Line(point_range.end.row)
 828                                    }),
 829                                );
 830                                predict_edits_v3::IncludedFile {
 831                                    path: path.clone(),
 832                                    max_row: Line(buffer.max_point().row),
 833                                    excerpts,
 834                                }
 835                            })
 836                            .collect::<Vec<_>>();
 837
 838                        predict_edits_v3::PredictEditsRequest {
 839                            excerpt_path,
 840                            excerpt: String::new(),
 841                            excerpt_line_range: Line(0)..Line(0),
 842                            excerpt_range: 0..0,
 843                            cursor_point: predict_edits_v3::Point {
 844                                line: predict_edits_v3::Line(cursor_point.row),
 845                                column: cursor_point.column,
 846                            },
 847                            included_files,
 848                            referenced_declarations: vec![],
 849                            events,
 850                            can_collect_data,
 851                            diagnostic_groups,
 852                            diagnostic_groups_truncated,
 853                            debug_info: debug_tx.is_some(),
 854                            prompt_max_bytes: Some(options.max_prompt_bytes),
 855                            prompt_format: options.prompt_format,
 856                            // TODO [zeta2]
 857                            signatures: vec![],
 858                            excerpt_parent: None,
 859                            git_info: None,
 860                        }
 861                    }
 862                    ContextMode::Syntax(context_options) => {
 863                        let Some(context) = EditPredictionContext::gather_context(
 864                            cursor_point,
 865                            &active_snapshot,
 866                            parent_abs_path.as_deref(),
 867                            &context_options,
 868                            index_state.as_deref(),
 869                        ) else {
 870                            return Ok((None, None));
 871                        };
 872
 873                        make_syntax_context_cloud_request(
 874                            excerpt_path,
 875                            context,
 876                            events,
 877                            can_collect_data,
 878                            diagnostic_groups,
 879                            diagnostic_groups_truncated,
 880                            None,
 881                            debug_tx.is_some(),
 882                            &worktree_snapshots,
 883                            index_state.as_deref(),
 884                            Some(options.max_prompt_bytes),
 885                            options.prompt_format,
 886                        )
 887                    }
 888                };
 889
 890                let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
 891
 892                let retrieval_time = chrono::Utc::now() - before_retrieval;
 893
 894                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 895                    let (response_tx, response_rx) = oneshot::channel();
 896
 897                    debug_tx
 898                        .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
 899                            ZetaEditPredictionDebugInfo {
 900                                request: cloud_request.clone(),
 901                                retrieval_time,
 902                                buffer: active_buffer.downgrade(),
 903                                local_prompt: match prompt_result.as_ref() {
 904                                    Ok((prompt, _)) => Ok(prompt.clone()),
 905                                    Err(err) => Err(err.to_string()),
 906                                },
 907                                position,
 908                                response_rx,
 909                            },
 910                        ))
 911                        .ok();
 912                    Some(response_tx)
 913                } else {
 914                    None
 915                };
 916
 917                if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 918                    if let Some(debug_response_tx) = debug_response_tx {
 919                        debug_response_tx
 920                            .send((Err("Request skipped".to_string()), TimeDelta::zero()))
 921                            .ok();
 922                    }
 923                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 924                }
 925
 926                let (prompt, _) = prompt_result?;
 927                let request = open_ai::Request {
 928                    model: MODEL_ID.clone(),
 929                    messages: vec![open_ai::RequestMessage::User {
 930                        content: open_ai::MessageContent::Plain(prompt),
 931                    }],
 932                    stream: false,
 933                    max_completion_tokens: None,
 934                    stop: Default::default(),
 935                    temperature: 0.7,
 936                    tool_choice: None,
 937                    parallel_tool_calls: None,
 938                    tools: vec![],
 939                    prompt_cache_key: None,
 940                    reasoning_effort: None,
 941                };
 942
 943                log::trace!("Sending edit prediction request");
 944
 945                let before_request = chrono::Utc::now();
 946                let response = Self::send_raw_llm_request(
 947                    request,
 948                    client,
 949                    llm_token,
 950                    app_version,
 951                    #[cfg(feature = "llm-response-cache")]
 952                    llm_response_cache,
 953                )
 954                .await;
 955                let request_time = chrono::Utc::now() - before_request;
 956
 957                log::trace!("Got edit prediction response");
 958
 959                if let Some(debug_response_tx) = debug_response_tx {
 960                    debug_response_tx
 961                        .send((
 962                            response
 963                                .as_ref()
 964                                .map_err(|err| err.to_string())
 965                                .map(|response| response.0.clone()),
 966                            request_time,
 967                        ))
 968                        .ok();
 969                }
 970
 971                let (res, usage) = response?;
 972                let request_id = EditPredictionId(res.id.clone().into());
 973                let Some(mut output_text) = text_from_response(res) else {
 974                    return Ok((None, usage));
 975                };
 976
 977                if output_text.contains(CURSOR_MARKER) {
 978                    log::trace!("Stripping out {CURSOR_MARKER} from response");
 979                    output_text = output_text.replace(CURSOR_MARKER, "");
 980                }
 981
 982                let get_buffer_from_context = |path: &Path| {
 983                    included_files
 984                        .iter()
 985                        .find_map(|(_, buffer, probe_path, ranges)| {
 986                            if probe_path.as_ref() == path {
 987                                Some((buffer, ranges.as_slice()))
 988                            } else {
 989                                None
 990                            }
 991                        })
 992                };
 993
 994                let (edited_buffer_snapshot, edits) = match options.prompt_format {
 995                    PromptFormat::NumLinesUniDiff => {
 996                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
 997                    }
 998                    PromptFormat::OldTextNewText => {
 999                        crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
1000                            .await?
1001                    }
1002                    _ => {
1003                        bail!("unsupported prompt format {}", options.prompt_format)
1004                    }
1005                };
1006
1007                let edited_buffer = included_files
1008                    .iter()
1009                    .find_map(|(buffer, snapshot, _, _)| {
1010                        if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
1011                            Some(buffer.clone())
1012                        } else {
1013                            None
1014                        }
1015                    })
1016                    .context("Failed to find buffer in included_buffers")?;
1017
1018                anyhow::Ok((
1019                    Some((
1020                        request_id,
1021                        edited_buffer,
1022                        edited_buffer_snapshot.clone(),
1023                        edits,
1024                    )),
1025                    usage,
1026                ))
1027            }
1028        });
1029
1030        cx.spawn({
1031            async move |this, cx| {
1032                let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
1033                    Self::handle_api_response(&this, request_task.await, cx)?
1034                else {
1035                    return Ok(None);
1036                };
1037
1038                // TODO telemetry: duration, etc
1039                Ok(
1040                    EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
1041                        .await,
1042                )
1043            }
1044        })
1045    }
1046
1047    async fn send_raw_llm_request(
1048        request: open_ai::Request,
1049        client: Arc<Client>,
1050        llm_token: LlmApiToken,
1051        app_version: SemanticVersion,
1052        #[cfg(feature = "llm-response-cache")] llm_response_cache: Option<
1053            Arc<dyn LlmResponseCache>,
1054        >,
1055    ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
1056        let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
1057            http_client::Url::parse(&predict_edits_url)?
1058        } else {
1059            client
1060                .http_client()
1061                .build_zed_llm_url("/predict_edits/raw", &[])?
1062        };
1063
1064        #[cfg(feature = "llm-response-cache")]
1065        let cache_key = if let Some(cache) = llm_response_cache {
1066            let request_json = serde_json::to_string(&request)?;
1067            let key = cache.get_key(&url, &request_json);
1068
1069            if let Some(response_str) = cache.read_response(key) {
1070                return Ok((serde_json::from_str(&response_str)?, None));
1071            }
1072
1073            Some((cache, key))
1074        } else {
1075            None
1076        };
1077
1078        let (response, usage) = Self::send_api_request(
1079            |builder| {
1080                let req = builder
1081                    .uri(url.as_ref())
1082                    .body(serde_json::to_string(&request)?.into());
1083                Ok(req?)
1084            },
1085            client,
1086            llm_token,
1087            app_version,
1088        )
1089        .await?;
1090
1091        #[cfg(feature = "llm-response-cache")]
1092        if let Some((cache, key)) = cache_key {
1093            cache.write_response(key, &serde_json::to_string(&response)?);
1094        }
1095
1096        Ok((response, usage))
1097    }
1098
1099    fn handle_api_response<T>(
1100        this: &WeakEntity<Self>,
1101        response: Result<(T, Option<EditPredictionUsage>)>,
1102        cx: &mut gpui::AsyncApp,
1103    ) -> Result<T> {
1104        match response {
1105            Ok((data, usage)) => {
1106                if let Some(usage) = usage {
1107                    this.update(cx, |this, cx| {
1108                        this.user_store.update(cx, |user_store, cx| {
1109                            user_store.update_edit_prediction_usage(usage, cx);
1110                        });
1111                    })
1112                    .ok();
1113                }
1114                Ok(data)
1115            }
1116            Err(err) => {
1117                if err.is::<ZedUpdateRequiredError>() {
1118                    cx.update(|cx| {
1119                        this.update(cx, |this, _cx| {
1120                            this.update_required = true;
1121                        })
1122                        .ok();
1123
1124                        let error_message: SharedString = err.to_string().into();
1125                        show_app_notification(
1126                            NotificationId::unique::<ZedUpdateRequiredError>(),
1127                            cx,
1128                            move |cx| {
1129                                cx.new(|cx| {
1130                                    ErrorMessagePrompt::new(error_message.clone(), cx)
1131                                        .with_link_button("Update Zed", "https://zed.dev/releases")
1132                                })
1133                            },
1134                        );
1135                    })
1136                    .ok();
1137                }
1138                Err(err)
1139            }
1140        }
1141    }
1142
1143    async fn send_api_request<Res>(
1144        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
1145        client: Arc<Client>,
1146        llm_token: LlmApiToken,
1147        app_version: SemanticVersion,
1148    ) -> Result<(Res, Option<EditPredictionUsage>)>
1149    where
1150        Res: DeserializeOwned,
1151    {
1152        let http_client = client.http_client();
1153        let mut token = llm_token.acquire(&client).await?;
1154        let mut did_retry = false;
1155
1156        loop {
1157            let request_builder = http_client::Request::builder().method(Method::POST);
1158
1159            let request = build(
1160                request_builder
1161                    .header("Content-Type", "application/json")
1162                    .header("Authorization", format!("Bearer {}", token))
1163                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
1164            )?;
1165
1166            let mut response = http_client.send(request).await?;
1167
1168            if let Some(minimum_required_version) = response
1169                .headers()
1170                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
1171                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
1172            {
1173                anyhow::ensure!(
1174                    app_version >= minimum_required_version,
1175                    ZedUpdateRequiredError {
1176                        minimum_version: minimum_required_version
1177                    }
1178                );
1179            }
1180
1181            if response.status().is_success() {
1182                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
1183
1184                let mut body = Vec::new();
1185                response.body_mut().read_to_end(&mut body).await?;
1186                return Ok((serde_json::from_slice(&body)?, usage));
1187            } else if !did_retry
1188                && response
1189                    .headers()
1190                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
1191                    .is_some()
1192            {
1193                did_retry = true;
1194                token = llm_token.refresh(&client).await?;
1195            } else {
1196                let mut body = String::new();
1197                response.body_mut().read_to_string(&mut body).await?;
1198                anyhow::bail!(
1199                    "Request failed with status: {:?}\nBody: {}",
1200                    response.status(),
1201                    body
1202                );
1203            }
1204        }
1205    }
1206
1207    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
1208    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
1209
1210    // Refresh the related excerpts when the user just beguns editing after
1211    // an idle period, and after they pause editing.
1212    fn refresh_context_if_needed(
1213        &mut self,
1214        project: &Entity<Project>,
1215        buffer: &Entity<language::Buffer>,
1216        cursor_position: language::Anchor,
1217        cx: &mut Context<Self>,
1218    ) {
1219        if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
1220            return;
1221        }
1222
1223        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1224            return;
1225        };
1226
1227        let now = Instant::now();
1228        let was_idle = zeta_project
1229            .refresh_context_timestamp
1230            .map_or(true, |timestamp| {
1231                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1232            });
1233        zeta_project.refresh_context_timestamp = Some(now);
1234        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1235            let buffer = buffer.clone();
1236            let project = project.clone();
1237            async move |this, cx| {
1238                if was_idle {
1239                    log::debug!("refetching edit prediction context after idle");
1240                } else {
1241                    cx.background_executor()
1242                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1243                        .await;
1244                    log::debug!("refetching edit prediction context after pause");
1245                }
1246                this.update(cx, |this, cx| {
1247                    let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
1248
1249                    if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
1250                        zeta_project.refresh_context_task = Some(task.log_err());
1251                    };
1252                })
1253                .ok()
1254            }
1255        }));
1256    }
1257
1258    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1259    // and avoid spawning more than one concurrent task.
1260    pub fn refresh_context(
1261        &mut self,
1262        project: Entity<Project>,
1263        buffer: Entity<language::Buffer>,
1264        cursor_position: language::Anchor,
1265        cx: &mut Context<Self>,
1266    ) -> Task<Result<()>> {
1267        let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
1268            return Task::ready(anyhow::Ok(()));
1269        };
1270
1271        let ContextMode::Agentic(options) = &self.options().context else {
1272            return Task::ready(anyhow::Ok(()));
1273        };
1274
1275        let snapshot = buffer.read(cx).snapshot();
1276        let cursor_point = cursor_position.to_point(&snapshot);
1277        let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
1278            cursor_point,
1279            &snapshot,
1280            &options.excerpt,
1281            None,
1282        ) else {
1283            return Task::ready(Ok(()));
1284        };
1285
1286        let app_version = AppVersion::global(cx);
1287        let client = self.client.clone();
1288        let llm_token = self.llm_token.clone();
1289        let debug_tx = self.debug_tx.clone();
1290        let current_file_path: Arc<Path> = snapshot
1291            .file()
1292            .map(|f| f.full_path(cx).into())
1293            .unwrap_or_else(|| Path::new("untitled").into());
1294
1295        let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
1296            predict_edits_v3::PlanContextRetrievalRequest {
1297                excerpt: cursor_excerpt.text(&snapshot).body,
1298                excerpt_path: current_file_path,
1299                excerpt_line_range: cursor_excerpt.line_range,
1300                cursor_file_max_row: Line(snapshot.max_point().row),
1301                events: zeta_project
1302                    .events
1303                    .iter()
1304                    .filter_map(|ev| ev.to_request_event(cx))
1305                    .collect(),
1306            },
1307        ) {
1308            Ok(prompt) => prompt,
1309            Err(err) => {
1310                return Task::ready(Err(err));
1311            }
1312        };
1313
1314        if let Some(debug_tx) = &debug_tx {
1315            debug_tx
1316                .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
1317                    ZetaContextRetrievalStartedDebugInfo {
1318                        project: project.clone(),
1319                        timestamp: Instant::now(),
1320                        search_prompt: prompt.clone(),
1321                    },
1322                ))
1323                .ok();
1324        }
1325
1326        pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
1327            let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
1328                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
1329            );
1330
1331            let description = schema
1332                .get("description")
1333                .and_then(|description| description.as_str())
1334                .unwrap()
1335                .to_string();
1336
1337            (schema.into(), description)
1338        });
1339
1340        let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
1341
1342        let request = open_ai::Request {
1343            model: MODEL_ID.clone(),
1344            messages: vec![open_ai::RequestMessage::User {
1345                content: open_ai::MessageContent::Plain(prompt),
1346            }],
1347            stream: false,
1348            max_completion_tokens: None,
1349            stop: Default::default(),
1350            temperature: 0.7,
1351            tool_choice: None,
1352            parallel_tool_calls: None,
1353            tools: vec![open_ai::ToolDefinition::Function {
1354                function: FunctionDefinition {
1355                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
1356                    description: Some(tool_description),
1357                    parameters: Some(tool_schema),
1358                },
1359            }],
1360            prompt_cache_key: None,
1361            reasoning_effort: None,
1362        };
1363
1364        #[cfg(feature = "llm-response-cache")]
1365        let llm_response_cache = self.llm_response_cache.clone();
1366
1367        cx.spawn(async move |this, cx| {
1368            log::trace!("Sending search planning request");
1369            let response = Self::send_raw_llm_request(
1370                request,
1371                client,
1372                llm_token,
1373                app_version,
1374                #[cfg(feature = "llm-response-cache")]
1375                llm_response_cache,
1376            )
1377            .await;
1378            let mut response = Self::handle_api_response(&this, response, cx)?;
1379            log::trace!("Got search planning response");
1380
1381            let choice = response
1382                .choices
1383                .pop()
1384                .context("No choices in retrieval response")?;
1385            let open_ai::RequestMessage::Assistant {
1386                content: _,
1387                tool_calls,
1388            } = choice.message
1389            else {
1390                anyhow::bail!("Retrieval response didn't include an assistant message");
1391            };
1392
1393            let mut queries: Vec<SearchToolQuery> = Vec::new();
1394            for tool_call in tool_calls {
1395                let open_ai::ToolCallContent::Function { function } = tool_call.content;
1396                if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
1397                    log::warn!(
1398                        "Context retrieval response tried to call an unknown tool: {}",
1399                        function.name
1400                    );
1401
1402                    continue;
1403                }
1404
1405                let input: SearchToolInput = serde_json::from_str(&function.arguments)
1406                    .with_context(|| format!("invalid search json {}", &function.arguments))?;
1407                queries.extend(input.queries);
1408            }
1409
1410            if let Some(debug_tx) = &debug_tx {
1411                debug_tx
1412                    .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
1413                        ZetaSearchQueryDebugInfo {
1414                            project: project.clone(),
1415                            timestamp: Instant::now(),
1416                            search_queries: queries.clone(),
1417                        },
1418                    ))
1419                    .ok();
1420            }
1421
1422            log::trace!("Running retrieval search: {queries:#?}");
1423
1424            let related_excerpts_result =
1425                retrieval_search::run_retrieval_searches(project.clone(), queries, cx).await;
1426
1427            log::trace!("Search queries executed");
1428
1429            if let Some(debug_tx) = &debug_tx {
1430                debug_tx
1431                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
1432                        ZetaContextRetrievalDebugInfo {
1433                            project: project.clone(),
1434                            timestamp: Instant::now(),
1435                        },
1436                    ))
1437                    .ok();
1438            }
1439
1440            this.update(cx, |this, _cx| {
1441                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1442                    return Ok(());
1443                };
1444                zeta_project.refresh_context_task.take();
1445                if let Some(debug_tx) = &this.debug_tx {
1446                    debug_tx
1447                        .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
1448                            ZetaContextRetrievalDebugInfo {
1449                                project,
1450                                timestamp: Instant::now(),
1451                            },
1452                        ))
1453                        .ok();
1454                }
1455                match related_excerpts_result {
1456                    Ok(excerpts) => {
1457                        zeta_project.context = Some(excerpts);
1458                        Ok(())
1459                    }
1460                    Err(error) => Err(error),
1461                }
1462            })?
1463        })
1464    }
1465
1466    pub fn set_context(
1467        &mut self,
1468        project: Entity<Project>,
1469        context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
1470    ) {
1471        if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
1472            zeta_project.context = Some(context);
1473        }
1474    }
1475
1476    fn gather_nearby_diagnostics(
1477        cursor_offset: usize,
1478        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1479        snapshot: &BufferSnapshot,
1480        max_diagnostics_bytes: usize,
1481    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1482        // TODO: Could make this more efficient
1483        let mut diagnostic_groups = Vec::new();
1484        for (language_server_id, diagnostics) in diagnostic_sets {
1485            let mut groups = Vec::new();
1486            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1487            diagnostic_groups.extend(
1488                groups
1489                    .into_iter()
1490                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1491            );
1492        }
1493
1494        // sort by proximity to cursor
1495        diagnostic_groups.sort_by_key(|group| {
1496            let range = &group.entries[group.primary_ix].range;
1497            if range.start >= cursor_offset {
1498                range.start - cursor_offset
1499            } else if cursor_offset >= range.end {
1500                cursor_offset - range.end
1501            } else {
1502                (cursor_offset - range.start).min(range.end - cursor_offset)
1503            }
1504        });
1505
1506        let mut results = Vec::new();
1507        let mut diagnostic_groups_truncated = false;
1508        let mut diagnostics_byte_count = 0;
1509        for group in diagnostic_groups {
1510            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1511            diagnostics_byte_count += raw_value.get().len();
1512            if diagnostics_byte_count > max_diagnostics_bytes {
1513                diagnostic_groups_truncated = true;
1514                break;
1515            }
1516            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1517        }
1518
1519        (results, diagnostic_groups_truncated)
1520    }
1521
1522    // TODO: Dedupe with similar code in request_prediction?
1523    pub fn cloud_request_for_zeta_cli(
1524        &mut self,
1525        project: &Entity<Project>,
1526        buffer: &Entity<Buffer>,
1527        position: language::Anchor,
1528        cx: &mut Context<Self>,
1529    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1530        let project_state = self.projects.get(&project.entity_id());
1531
1532        let index_state = project_state.map(|state| {
1533            state
1534                .syntax_index
1535                .read_with(cx, |index, _cx| index.state().clone())
1536        });
1537        let options = self.options.clone();
1538        let snapshot = buffer.read(cx).snapshot();
1539        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1540            return Task::ready(Err(anyhow!("No file path for excerpt")));
1541        };
1542        let worktree_snapshots = project
1543            .read(cx)
1544            .worktrees(cx)
1545            .map(|worktree| worktree.read(cx).snapshot())
1546            .collect::<Vec<_>>();
1547
1548        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1549            let mut path = f.worktree.read(cx).absolutize(&f.path);
1550            if path.pop() { Some(path) } else { None }
1551        });
1552
1553        cx.background_spawn(async move {
1554            let index_state = if let Some(index_state) = index_state {
1555                Some(index_state.lock_owned().await)
1556            } else {
1557                None
1558            };
1559
1560            let cursor_point = position.to_point(&snapshot);
1561
1562            let debug_info = true;
1563            EditPredictionContext::gather_context(
1564                cursor_point,
1565                &snapshot,
1566                parent_abs_path.as_deref(),
1567                match &options.context {
1568                    ContextMode::Agentic(_) => {
1569                        // TODO
1570                        panic!("Llm mode not supported in zeta cli yet");
1571                    }
1572                    ContextMode::Syntax(edit_prediction_context_options) => {
1573                        edit_prediction_context_options
1574                    }
1575                },
1576                index_state.as_deref(),
1577            )
1578            .context("Failed to select excerpt")
1579            .map(|context| {
1580                make_syntax_context_cloud_request(
1581                    excerpt_path.into(),
1582                    context,
1583                    // TODO pass everything
1584                    Vec::new(),
1585                    false,
1586                    Vec::new(),
1587                    false,
1588                    None,
1589                    debug_info,
1590                    &worktree_snapshots,
1591                    index_state.as_deref(),
1592                    Some(options.max_prompt_bytes),
1593                    options.prompt_format,
1594                )
1595            })
1596        })
1597    }
1598
1599    pub fn wait_for_initial_indexing(
1600        &mut self,
1601        project: &Entity<Project>,
1602        cx: &mut App,
1603    ) -> Task<Result<()>> {
1604        let zeta_project = self.get_or_init_zeta_project(project, cx);
1605        zeta_project
1606            .syntax_index
1607            .read(cx)
1608            .wait_for_initial_file_indexing(cx)
1609    }
1610}
1611
1612pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
1613    let choice = res.choices.pop()?;
1614    let output_text = match choice.message {
1615        open_ai::RequestMessage::Assistant {
1616            content: Some(open_ai::MessageContent::Plain(content)),
1617            ..
1618        } => content,
1619        open_ai::RequestMessage::Assistant {
1620            content: Some(open_ai::MessageContent::Multipart(mut content)),
1621            ..
1622        } => {
1623            if content.is_empty() {
1624                log::error!("No output from Baseten completion response");
1625                return None;
1626            }
1627
1628            match content.remove(0) {
1629                open_ai::MessagePart::Text { text } => text,
1630                open_ai::MessagePart::Image { .. } => {
1631                    log::error!("Expected text, got an image");
1632                    return None;
1633                }
1634            }
1635        }
1636        _ => {
1637            log::error!("Invalid response message: {:?}", choice.message);
1638            return None;
1639        }
1640    };
1641    Some(output_text)
1642}
1643
1644#[derive(Error, Debug)]
1645#[error(
1646    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1647)]
1648pub struct ZedUpdateRequiredError {
1649    minimum_version: SemanticVersion,
1650}
1651
1652fn make_syntax_context_cloud_request(
1653    excerpt_path: Arc<Path>,
1654    context: EditPredictionContext,
1655    events: Vec<predict_edits_v3::Event>,
1656    can_collect_data: bool,
1657    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1658    diagnostic_groups_truncated: bool,
1659    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1660    debug_info: bool,
1661    worktrees: &Vec<worktree::Snapshot>,
1662    index_state: Option<&SyntaxIndexState>,
1663    prompt_max_bytes: Option<usize>,
1664    prompt_format: PromptFormat,
1665) -> predict_edits_v3::PredictEditsRequest {
1666    let mut signatures = Vec::new();
1667    let mut declaration_to_signature_index = HashMap::default();
1668    let mut referenced_declarations = Vec::new();
1669
1670    for snippet in context.declarations {
1671        let project_entry_id = snippet.declaration.project_entry_id();
1672        let Some(path) = worktrees.iter().find_map(|worktree| {
1673            worktree.entry_for_id(project_entry_id).map(|entry| {
1674                let mut full_path = RelPathBuf::new();
1675                full_path.push(worktree.root_name());
1676                full_path.push(&entry.path);
1677                full_path
1678            })
1679        }) else {
1680            continue;
1681        };
1682
1683        let parent_index = index_state.and_then(|index_state| {
1684            snippet.declaration.parent().and_then(|parent| {
1685                add_signature(
1686                    parent,
1687                    &mut declaration_to_signature_index,
1688                    &mut signatures,
1689                    index_state,
1690                )
1691            })
1692        });
1693
1694        let (text, text_is_truncated) = snippet.declaration.item_text();
1695        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1696            path: path.as_std_path().into(),
1697            text: text.into(),
1698            range: snippet.declaration.item_line_range(),
1699            text_is_truncated,
1700            signature_range: snippet.declaration.signature_range_in_item_text(),
1701            parent_index,
1702            signature_score: snippet.score(DeclarationStyle::Signature),
1703            declaration_score: snippet.score(DeclarationStyle::Declaration),
1704            score_components: snippet.components,
1705        });
1706    }
1707
1708    let excerpt_parent = index_state.and_then(|index_state| {
1709        context
1710            .excerpt
1711            .parent_declarations
1712            .last()
1713            .and_then(|(parent, _)| {
1714                add_signature(
1715                    *parent,
1716                    &mut declaration_to_signature_index,
1717                    &mut signatures,
1718                    index_state,
1719                )
1720            })
1721    });
1722
1723    predict_edits_v3::PredictEditsRequest {
1724        excerpt_path,
1725        excerpt: context.excerpt_text.body,
1726        excerpt_line_range: context.excerpt.line_range,
1727        excerpt_range: context.excerpt.range,
1728        cursor_point: predict_edits_v3::Point {
1729            line: predict_edits_v3::Line(context.cursor_point.row),
1730            column: context.cursor_point.column,
1731        },
1732        referenced_declarations,
1733        included_files: vec![],
1734        signatures,
1735        excerpt_parent,
1736        events,
1737        can_collect_data,
1738        diagnostic_groups,
1739        diagnostic_groups_truncated,
1740        git_info,
1741        debug_info,
1742        prompt_max_bytes,
1743        prompt_format,
1744    }
1745}
1746
1747fn add_signature(
1748    declaration_id: DeclarationId,
1749    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1750    signatures: &mut Vec<Signature>,
1751    index: &SyntaxIndexState,
1752) -> Option<usize> {
1753    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1754        return Some(*signature_index);
1755    }
1756    let Some(parent_declaration) = index.declaration(declaration_id) else {
1757        log::error!("bug: missing parent declaration");
1758        return None;
1759    };
1760    let parent_index = parent_declaration.parent().and_then(|parent| {
1761        add_signature(parent, declaration_to_signature_index, signatures, index)
1762    });
1763    let (text, text_is_truncated) = parent_declaration.signature_text();
1764    let signature_index = signatures.len();
1765    signatures.push(Signature {
1766        text: text.into(),
1767        text_is_truncated,
1768        parent_index,
1769        range: parent_declaration.signature_line_range(),
1770    });
1771    declaration_to_signature_index.insert(declaration_id, signature_index);
1772    Some(signature_index)
1773}
1774
1775#[cfg(test)]
1776mod tests {
1777    use std::{path::Path, sync::Arc};
1778
1779    use client::UserStore;
1780    use clock::FakeSystemClock;
1781    use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
1782    use futures::{
1783        AsyncReadExt, StreamExt,
1784        channel::{mpsc, oneshot},
1785    };
1786    use gpui::{
1787        Entity, TestAppContext,
1788        http_client::{FakeHttpClient, Response},
1789        prelude::*,
1790    };
1791    use indoc::indoc;
1792    use language::OffsetRangeExt as _;
1793    use open_ai::Usage;
1794    use pretty_assertions::{assert_eq, assert_matches};
1795    use project::{FakeFs, Project};
1796    use serde_json::json;
1797    use settings::SettingsStore;
1798    use util::path;
1799    use uuid::Uuid;
1800
1801    use crate::{BufferEditPrediction, Zeta};
1802
1803    #[gpui::test]
1804    async fn test_current_state(cx: &mut TestAppContext) {
1805        let (zeta, mut req_rx) = init_test(cx);
1806        let fs = FakeFs::new(cx.executor());
1807        fs.insert_tree(
1808            "/root",
1809            json!({
1810                "1.txt": "Hello!\nHow\nBye\n",
1811                "2.txt": "Hola!\nComo\nAdios\n"
1812            }),
1813        )
1814        .await;
1815        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1816
1817        zeta.update(cx, |zeta, cx| {
1818            zeta.register_project(&project, cx);
1819        });
1820
1821        let buffer1 = project
1822            .update(cx, |project, cx| {
1823                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1824                project.open_buffer(path, cx)
1825            })
1826            .await
1827            .unwrap();
1828        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1829        let position = snapshot1.anchor_before(language::Point::new(1, 3));
1830
1831        // Prediction for current file
1832
1833        let prediction_task = zeta.update(cx, |zeta, cx| {
1834            zeta.refresh_prediction(&project, &buffer1, position, cx)
1835        });
1836        let (_request, respond_tx) = req_rx.next().await.unwrap();
1837
1838        respond_tx
1839            .send(model_response(indoc! {r"
1840                --- a/root/1.txt
1841                +++ b/root/1.txt
1842                @@ ... @@
1843                 Hello!
1844                -How
1845                +How are you?
1846                 Bye
1847            "}))
1848            .unwrap();
1849        prediction_task.await.unwrap();
1850
1851        zeta.read_with(cx, |zeta, cx| {
1852            let prediction = zeta
1853                .current_prediction_for_buffer(&buffer1, &project, cx)
1854                .unwrap();
1855            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1856        });
1857
1858        // Context refresh
1859        let refresh_task = zeta.update(cx, |zeta, cx| {
1860            zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
1861        });
1862        let (_request, respond_tx) = req_rx.next().await.unwrap();
1863        respond_tx
1864            .send(open_ai::Response {
1865                id: Uuid::new_v4().to_string(),
1866                object: "response".into(),
1867                created: 0,
1868                model: "model".into(),
1869                choices: vec![open_ai::Choice {
1870                    index: 0,
1871                    message: open_ai::RequestMessage::Assistant {
1872                        content: None,
1873                        tool_calls: vec![open_ai::ToolCall {
1874                            id: "search".into(),
1875                            content: open_ai::ToolCallContent::Function {
1876                                function: open_ai::FunctionContent {
1877                                    name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
1878                                        .to_string(),
1879                                    arguments: serde_json::to_string(&SearchToolInput {
1880                                        queries: Box::new([SearchToolQuery {
1881                                            glob: "root/2.txt".to_string(),
1882                                            syntax_node: vec![],
1883                                            content: Some(".".into()),
1884                                        }]),
1885                                    })
1886                                    .unwrap(),
1887                                },
1888                            },
1889                        }],
1890                    },
1891                    finish_reason: None,
1892                }],
1893                usage: Usage {
1894                    prompt_tokens: 0,
1895                    completion_tokens: 0,
1896                    total_tokens: 0,
1897                },
1898            })
1899            .unwrap();
1900        refresh_task.await.unwrap();
1901
1902        zeta.update(cx, |zeta, _cx| {
1903            zeta.discard_current_prediction(&project);
1904        });
1905
1906        // Prediction for another file
1907        let prediction_task = zeta.update(cx, |zeta, cx| {
1908            zeta.refresh_prediction(&project, &buffer1, position, cx)
1909        });
1910        let (_request, respond_tx) = req_rx.next().await.unwrap();
1911        respond_tx
1912            .send(model_response(indoc! {r#"
1913                --- a/root/2.txt
1914                +++ b/root/2.txt
1915                 Hola!
1916                -Como
1917                +Como estas?
1918                 Adios
1919            "#}))
1920            .unwrap();
1921        prediction_task.await.unwrap();
1922        zeta.read_with(cx, |zeta, cx| {
1923            let prediction = zeta
1924                .current_prediction_for_buffer(&buffer1, &project, cx)
1925                .unwrap();
1926            assert_matches!(
1927                prediction,
1928                BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
1929            );
1930        });
1931
1932        let buffer2 = project
1933            .update(cx, |project, cx| {
1934                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1935                project.open_buffer(path, cx)
1936            })
1937            .await
1938            .unwrap();
1939
1940        zeta.read_with(cx, |zeta, cx| {
1941            let prediction = zeta
1942                .current_prediction_for_buffer(&buffer2, &project, cx)
1943                .unwrap();
1944            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1945        });
1946    }
1947
1948    #[gpui::test]
1949    async fn test_simple_request(cx: &mut TestAppContext) {
1950        let (zeta, mut req_rx) = init_test(cx);
1951        let fs = FakeFs::new(cx.executor());
1952        fs.insert_tree(
1953            "/root",
1954            json!({
1955                "foo.md":  "Hello!\nHow\nBye\n"
1956            }),
1957        )
1958        .await;
1959        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1960
1961        let buffer = project
1962            .update(cx, |project, cx| {
1963                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1964                project.open_buffer(path, cx)
1965            })
1966            .await
1967            .unwrap();
1968        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1969        let position = snapshot.anchor_before(language::Point::new(1, 3));
1970
1971        let prediction_task = zeta.update(cx, |zeta, cx| {
1972            zeta.request_prediction(&project, &buffer, position, cx)
1973        });
1974
1975        let (_, respond_tx) = req_rx.next().await.unwrap();
1976
1977        // TODO Put back when we have a structured request again
1978        // assert_eq!(
1979        //     request.excerpt_path.as_ref(),
1980        //     Path::new(path!("root/foo.md"))
1981        // );
1982        // assert_eq!(
1983        //     request.cursor_point,
1984        //     Point {
1985        //         line: Line(1),
1986        //         column: 3
1987        //     }
1988        // );
1989
1990        respond_tx
1991            .send(model_response(indoc! { r"
1992                --- a/root/foo.md
1993                +++ b/root/foo.md
1994                @@ ... @@
1995                 Hello!
1996                -How
1997                +How are you?
1998                 Bye
1999            "}))
2000            .unwrap();
2001
2002        let prediction = prediction_task.await.unwrap().unwrap();
2003
2004        assert_eq!(prediction.edits.len(), 1);
2005        assert_eq!(
2006            prediction.edits[0].0.to_point(&snapshot).start,
2007            language::Point::new(1, 3)
2008        );
2009        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
2010    }
2011
2012    #[gpui::test]
2013    async fn test_request_events(cx: &mut TestAppContext) {
2014        let (zeta, mut req_rx) = init_test(cx);
2015        let fs = FakeFs::new(cx.executor());
2016        fs.insert_tree(
2017            "/root",
2018            json!({
2019                "foo.md": "Hello!\n\nBye\n"
2020            }),
2021        )
2022        .await;
2023        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2024
2025        let buffer = project
2026            .update(cx, |project, cx| {
2027                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2028                project.open_buffer(path, cx)
2029            })
2030            .await
2031            .unwrap();
2032
2033        zeta.update(cx, |zeta, cx| {
2034            zeta.register_buffer(&buffer, &project, cx);
2035        });
2036
2037        buffer.update(cx, |buffer, cx| {
2038            buffer.edit(vec![(7..7, "How")], None, cx);
2039        });
2040
2041        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2042        let position = snapshot.anchor_before(language::Point::new(1, 3));
2043
2044        let prediction_task = zeta.update(cx, |zeta, cx| {
2045            zeta.request_prediction(&project, &buffer, position, cx)
2046        });
2047
2048        let (request, respond_tx) = req_rx.next().await.unwrap();
2049
2050        let prompt = prompt_from_request(&request);
2051        assert!(
2052            prompt.contains(indoc! {"
2053            --- a/root/foo.md
2054            +++ b/root/foo.md
2055            @@ -1,3 +1,3 @@
2056             Hello!
2057            -
2058            +How
2059             Bye
2060        "}),
2061            "{prompt}"
2062        );
2063
2064        respond_tx
2065            .send(model_response(indoc! {r#"
2066                --- a/root/foo.md
2067                +++ b/root/foo.md
2068                @@ ... @@
2069                 Hello!
2070                -How
2071                +How are you?
2072                 Bye
2073            "#}))
2074            .unwrap();
2075
2076        let prediction = prediction_task.await.unwrap().unwrap();
2077
2078        assert_eq!(prediction.edits.len(), 1);
2079        assert_eq!(
2080            prediction.edits[0].0.to_point(&snapshot).start,
2081            language::Point::new(1, 3)
2082        );
2083        assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
2084    }
2085
2086    // Skipped until we start including diagnostics in prompt
2087    // #[gpui::test]
2088    // async fn test_request_diagnostics(cx: &mut TestAppContext) {
2089    //     let (zeta, mut req_rx) = init_test(cx);
2090    //     let fs = FakeFs::new(cx.executor());
2091    //     fs.insert_tree(
2092    //         "/root",
2093    //         json!({
2094    //             "foo.md": "Hello!\nBye"
2095    //         }),
2096    //     )
2097    //     .await;
2098    //     let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
2099
2100    //     let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
2101    //     let diagnostic = lsp::Diagnostic {
2102    //         range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
2103    //         severity: Some(lsp::DiagnosticSeverity::ERROR),
2104    //         message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
2105    //         ..Default::default()
2106    //     };
2107
2108    //     project.update(cx, |project, cx| {
2109    //         project.lsp_store().update(cx, |lsp_store, cx| {
2110    //             // Create some diagnostics
2111    //             lsp_store
2112    //                 .update_diagnostics(
2113    //                     LanguageServerId(0),
2114    //                     lsp::PublishDiagnosticsParams {
2115    //                         uri: path_to_buffer_uri.clone(),
2116    //                         diagnostics: vec![diagnostic],
2117    //                         version: None,
2118    //                     },
2119    //                     None,
2120    //                     language::DiagnosticSourceKind::Pushed,
2121    //                     &[],
2122    //                     cx,
2123    //                 )
2124    //                 .unwrap();
2125    //         });
2126    //     });
2127
2128    //     let buffer = project
2129    //         .update(cx, |project, cx| {
2130    //             let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
2131    //             project.open_buffer(path, cx)
2132    //         })
2133    //         .await
2134    //         .unwrap();
2135
2136    //     let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
2137    //     let position = snapshot.anchor_before(language::Point::new(0, 0));
2138
2139    //     let _prediction_task = zeta.update(cx, |zeta, cx| {
2140    //         zeta.request_prediction(&project, &buffer, position, cx)
2141    //     });
2142
2143    //     let (request, _respond_tx) = req_rx.next().await.unwrap();
2144
2145    //     assert_eq!(request.diagnostic_groups.len(), 1);
2146    //     let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
2147    //         .unwrap();
2148    //     // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
2149    //     assert_eq!(
2150    //         value,
2151    //         json!({
2152    //             "entries": [{
2153    //                 "range": {
2154    //                     "start": 8,
2155    //                     "end": 10
2156    //                 },
2157    //                 "diagnostic": {
2158    //                     "source": null,
2159    //                     "code": null,
2160    //                     "code_description": null,
2161    //                     "severity": 1,
2162    //                     "message": "\"Hello\" deprecated. Use \"Hi\" instead",
2163    //                     "markdown": null,
2164    //                     "group_id": 0,
2165    //                     "is_primary": true,
2166    //                     "is_disk_based": false,
2167    //                     "is_unnecessary": false,
2168    //                     "source_kind": "Pushed",
2169    //                     "data": null,
2170    //                     "underline": true
2171    //                 }
2172    //             }],
2173    //             "primary_ix": 0
2174    //         })
2175    //     );
2176    // }
2177
2178    fn model_response(text: &str) -> open_ai::Response {
2179        open_ai::Response {
2180            id: Uuid::new_v4().to_string(),
2181            object: "response".into(),
2182            created: 0,
2183            model: "model".into(),
2184            choices: vec![open_ai::Choice {
2185                index: 0,
2186                message: open_ai::RequestMessage::Assistant {
2187                    content: Some(open_ai::MessageContent::Plain(text.to_string())),
2188                    tool_calls: vec![],
2189                },
2190                finish_reason: None,
2191            }],
2192            usage: Usage {
2193                prompt_tokens: 0,
2194                completion_tokens: 0,
2195                total_tokens: 0,
2196            },
2197        }
2198    }
2199
2200    fn prompt_from_request(request: &open_ai::Request) -> &str {
2201        assert_eq!(request.messages.len(), 1);
2202        let open_ai::RequestMessage::User {
2203            content: open_ai::MessageContent::Plain(content),
2204            ..
2205        } = &request.messages[0]
2206        else {
2207            panic!(
2208                "Request does not have single user message of type Plain. {:#?}",
2209                request
2210            );
2211        };
2212        content
2213    }
2214
2215    fn init_test(
2216        cx: &mut TestAppContext,
2217    ) -> (
2218        Entity<Zeta>,
2219        mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
2220    ) {
2221        cx.update(move |cx| {
2222            let settings_store = SettingsStore::test(cx);
2223            cx.set_global(settings_store);
2224            zlog::init_test();
2225
2226            let (req_tx, req_rx) = mpsc::unbounded();
2227
2228            let http_client = FakeHttpClient::create({
2229                move |req| {
2230                    let uri = req.uri().path().to_string();
2231                    let mut body = req.into_body();
2232                    let req_tx = req_tx.clone();
2233                    async move {
2234                        let resp = match uri.as_str() {
2235                            "/client/llm_tokens" => serde_json::to_string(&json!({
2236                                "token": "test"
2237                            }))
2238                            .unwrap(),
2239                            "/predict_edits/raw" => {
2240                                let mut buf = Vec::new();
2241                                body.read_to_end(&mut buf).await.ok();
2242                                let req = serde_json::from_slice(&buf).unwrap();
2243
2244                                let (res_tx, res_rx) = oneshot::channel();
2245                                req_tx.unbounded_send((req, res_tx)).unwrap();
2246                                serde_json::to_string(&res_rx.await?).unwrap()
2247                            }
2248                            _ => {
2249                                panic!("Unexpected path: {}", uri)
2250                            }
2251                        };
2252
2253                        Ok(Response::builder().body(resp.into()).unwrap())
2254                    }
2255                }
2256            });
2257
2258            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
2259            client.cloud_client().set_credentials(1, "test".into());
2260
2261            language_model::init(client.clone(), cx);
2262
2263            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2264            let zeta = Zeta::global(&client, &user_store, cx);
2265
2266            (zeta, req_rx)
2267        })
2268    }
2269}