zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
  10use collections::HashMap;
  11use edit_prediction_context::{
  12    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  13    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  14    SyntaxIndex, SyntaxIndexState,
  15};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::AsyncReadExt as _;
  18use futures::channel::{mpsc, oneshot};
  19use gpui::http_client::{AsyncBody, Method};
  20use gpui::{
  21    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  22    http_client, prelude::*,
  23};
  24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language::{BufferSnapshot, OffsetRangeExt};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use serde::de::DeserializeOwned;
  30use std::collections::{VecDeque, hash_map};
  31use std::ops::Range;
  32use std::path::Path;
  33use std::str::FromStr as _;
  34use std::sync::Arc;
  35use std::time::{Duration, Instant};
  36use thiserror::Error;
  37use util::ResultExt as _;
  38use util::rel_path::RelPathBuf;
  39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  40
  41mod merge_excerpts;
  42mod prediction;
  43mod provider;
  44mod related_excerpts;
  45
  46use crate::merge_excerpts::merge_excerpts;
  47use crate::prediction::EditPrediction;
  48pub use crate::related_excerpts::LlmContextOptions;
  49use crate::related_excerpts::find_related_excerpts;
  50pub use provider::ZetaEditPredictionProvider;
  51
  52const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  53
  54/// Maximum number of events to track.
  55const MAX_EVENT_COUNT: usize = 16;
  56
  57pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  58    max_bytes: 512,
  59    min_bytes: 128,
  60    target_before_cursor_over_total_bytes: 0.5,
  61};
  62
  63pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
  64
  65pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
  66    excerpt: DEFAULT_EXCERPT_OPTIONS,
  67};
  68
  69pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
  70    EditPredictionContextOptions {
  71        use_imports: true,
  72        max_retrieved_declarations: 0,
  73        excerpt: DEFAULT_EXCERPT_OPTIONS,
  74        score: EditPredictionScoreOptions {
  75            omit_excerpt_overlaps: true,
  76        },
  77    };
  78
  79pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  80    context: DEFAULT_CONTEXT_OPTIONS,
  81    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  82    max_diagnostic_bytes: 2048,
  83    prompt_format: PromptFormat::DEFAULT,
  84    file_indexing_parallelism: 1,
  85};
  86
/// Feature flag registered under the name `"zeta2"`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    /// Name under which the flag is looked up.
    const NAME: &'static str = "zeta2";

    /// Returns `false`. NOTE(review): presumably this means staff do not get the
    /// flag automatically and it must be enabled explicitly — confirm.
    fn enabled_for_staff() -> bool {
        false
    }
}
  96
/// App-global wrapper holding the shared `Zeta` entity; installed by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 101
/// Zeta2 edit-prediction state, holding per-project state for every registered project.
pub struct Zeta {
    /// Client used for HTTP requests and for refreshing the LLM token.
    client: Arc<Client>,
    /// Queried for edit prediction usage (see `Zeta::usage`).
    user_store: Entity<UserStore>,
    /// LLM API token; refreshed whenever `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription alive for as long as this `Zeta` lives.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Current options; read when creating project state and when building requests.
    options: ZetaOptions,
    /// Initialized to `false`. NOTE(review): not read in the visible code — presumably
    /// set when the server reports that a newer client version is required; confirm.
    update_required: bool,
    /// Installed by `Zeta::debug_info`; when present, each prediction request sends a
    /// `PredictionDebugInfo` through it.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
 112
/// Tunable settings for Zeta2 prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered (`Llm` or `Syntax` mode).
    pub context: ContextMode,
    /// Sent with each request as `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte limit passed to `gather_nearby_diagnostics` when collecting diagnostics.
    pub max_diagnostic_bytes: usize,
    /// Prompt format sent with each request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project's state is first created.
    pub file_indexing_parallelism: usize,
}
 121
/// Strategy used to gather context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Context is the cursor excerpt plus the buffer ranges stored in the project's
    /// `context` map, sent as included files.
    Llm(LlmContextOptions),
    /// Context is gathered via `EditPredictionContext::gather_context` using the
    /// project's syntax index.
    Syntax(EditPredictionContextOptions),
}
 127
 128impl ContextMode {
 129    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 130        match self {
 131            ContextMode::Llm(options) => &options.excerpt,
 132            ContextMode::Syntax(options) => &options.excerpt,
 133        }
 134    }
 135}
 136
/// Debug information emitted for each prediction request while a debug receiver
/// (see `Zeta::debug_info`) is installed.
pub struct PredictionDebugInfo {
    /// The request built for the server.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics/context and building `request`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt produced locally by `build_prompt(&request)`, or its error message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response, or with an error message if the request
    /// failed or was skipped.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Alias for `predict_edits_v3::DebugInfo`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 147
/// Per-project Zeta state.
struct ZetaProject {
    /// Syntax index over the project; its state is read when building requests.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first. When the queue reaches
    /// `MAX_EVENT_COUNT`, the oldest half is dropped before pushing.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently held for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Extra buffer ranges sent as included files in `Llm`-mode requests.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    // NOTE(review): the refresh_context_* fields are not used in the visible code;
    // presumably they track the in-flight / debounced task that recomputes `context`.
    refresh_context_task: Option<Task<Option<()>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}
 158
/// The prediction currently held for a project, plus the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the request came from; may differ from the buffer
    /// the prediction edits.
    pub requested_by_buffer_id: EntityId,
    /// The prediction itself.
    pub prediction: EditPrediction,
}
 164
 165impl CurrentEditPrediction {
 166    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 167        let Some(new_edits) = self
 168            .prediction
 169            .interpolate(&self.prediction.buffer.read(cx))
 170        else {
 171            return false;
 172        };
 173
 174        if self.prediction.buffer != old_prediction.prediction.buffer {
 175            return true;
 176        }
 177
 178        let Some(old_edits) = old_prediction
 179            .prediction
 180            .interpolate(&old_prediction.prediction.buffer.read(cx))
 181        else {
 182            return true;
 183        };
 184
 185        // This reduces the occurrence of UI thrash from replacing edits
 186        //
 187        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 188        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 189            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 190            && old_edits.len() == 1
 191            && new_edits.len() == 1
 192        {
 193            let (old_range, old_text) = &old_edits[0];
 194            let (new_range, new_text) = &new_edits[0];
 195            new_range == old_range && new_text.starts_with(old_text)
 196        } else {
 197            true
 198        }
 199    }
 200}
 201
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction edits a different buffer, but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}
 208
/// A buffer whose edits are being tracked for a project.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer, kept alive while the buffer is registered.
    _subscriptions: [gpui::Subscription; 2],
}
 213
/// A recorded change used as history for edit prediction.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes
    /// to the same buffer within `BUFFER_CHANGE_GROUPING_INTERVAL` are coalesced
    /// into one event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent change folded into this event was recorded.
        timestamp: Instant,
    },
}
 222
 223impl Event {
 224    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 225        match self {
 226            Event::BufferChange {
 227                old_snapshot,
 228                new_snapshot,
 229                ..
 230            } => {
 231                let path = new_snapshot.file().map(|f| f.full_path(cx));
 232
 233                let old_path = old_snapshot.file().and_then(|f| {
 234                    let old_path = f.full_path(cx);
 235                    if Some(&old_path) != path.as_ref() {
 236                        Some(old_path)
 237                    } else {
 238                        None
 239                    }
 240                });
 241
 242                // TODO [zeta2] move to bg?
 243                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 244
 245                if path == old_path && diff.is_empty() {
 246                    None
 247                } else {
 248                    Some(predict_edits_v3::Event::BufferChange {
 249                        old_path,
 250                        path,
 251                        diff,
 252                        //todo: Actually detect if this edit was predicted or not
 253                        predicted: false,
 254                    })
 255                }
 256            }
 257        }
 258    }
 259}
 260
 261impl Zeta {
 262    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 263        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 264    }
 265
 266    pub fn global(
 267        client: &Arc<Client>,
 268        user_store: &Entity<UserStore>,
 269        cx: &mut App,
 270    ) -> Entity<Self> {
 271        cx.try_global::<ZetaGlobal>()
 272            .map(|global| global.0.clone())
 273            .unwrap_or_else(|| {
 274                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 275                cx.set_global(ZetaGlobal(zeta.clone()));
 276                zeta
 277            })
 278    }
 279
 280    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 281        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 282
 283        Self {
 284            projects: HashMap::default(),
 285            client,
 286            user_store,
 287            options: DEFAULT_OPTIONS,
 288            llm_token: LlmApiToken::default(),
 289            _llm_token_subscription: cx.subscribe(
 290                &refresh_llm_token_listener,
 291                |this, _listener, _event, cx| {
 292                    let client = this.client.clone();
 293                    let llm_token = this.llm_token.clone();
 294                    cx.spawn(async move |_this, _cx| {
 295                        llm_token.refresh(&client).await?;
 296                        anyhow::Ok(())
 297                    })
 298                    .detach_and_log_err(cx);
 299                },
 300            ),
 301            update_required: false,
 302            debug_tx: None,
 303        }
 304    }
 305
 306    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 307        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 308        self.debug_tx = Some(debug_watch_tx);
 309        debug_watch_rx
 310    }
 311
 312    pub fn options(&self) -> &ZetaOptions {
 313        &self.options
 314    }
 315
 316    pub fn set_options(&mut self, options: ZetaOptions) {
 317        self.options = options;
 318    }
 319
 320    pub fn clear_history(&mut self) {
 321        for zeta_project in self.projects.values_mut() {
 322            zeta_project.events.clear();
 323        }
 324    }
 325
 326    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 327        static EMPTY_EVENTS: VecDeque<Event> = VecDeque::new();
 328        self.projects
 329            .get(&project.entity_id())
 330            .map_or(&EMPTY_EVENTS, |project| &project.events)
 331            .iter()
 332    }
 333
 334    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 335        self.user_store.read(cx).edit_prediction_usage()
 336    }
 337
 338    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 339        self.get_or_init_zeta_project(project, cx);
 340    }
 341
 342    pub fn register_buffer(
 343        &mut self,
 344        buffer: &Entity<Buffer>,
 345        project: &Entity<Project>,
 346        cx: &mut Context<Self>,
 347    ) {
 348        let zeta_project = self.get_or_init_zeta_project(project, cx);
 349        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 350    }
 351
 352    fn get_or_init_zeta_project(
 353        &mut self,
 354        project: &Entity<Project>,
 355        cx: &mut App,
 356    ) -> &mut ZetaProject {
 357        self.projects
 358            .entry(project.entity_id())
 359            .or_insert_with(|| ZetaProject {
 360                syntax_index: cx.new(|cx| {
 361                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 362                }),
 363                events: VecDeque::new(),
 364                registered_buffers: HashMap::default(),
 365                current_prediction: None,
 366                context: None,
 367                refresh_context_task: None,
 368                refresh_context_debounce_task: None,
 369                refresh_context_timestamp: None,
 370            })
 371    }
 372
 373    fn register_buffer_impl<'a>(
 374        zeta_project: &'a mut ZetaProject,
 375        buffer: &Entity<Buffer>,
 376        project: &Entity<Project>,
 377        cx: &mut Context<Self>,
 378    ) -> &'a mut RegisteredBuffer {
 379        let buffer_id = buffer.entity_id();
 380        match zeta_project.registered_buffers.entry(buffer_id) {
 381            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 382            hash_map::Entry::Vacant(entry) => {
 383                let snapshot = buffer.read(cx).snapshot();
 384                let project_entity_id = project.entity_id();
 385                entry.insert(RegisteredBuffer {
 386                    snapshot,
 387                    _subscriptions: [
 388                        cx.subscribe(buffer, {
 389                            let project = project.downgrade();
 390                            move |this, buffer, event, cx| {
 391                                if let language::BufferEvent::Edited = event
 392                                    && let Some(project) = project.upgrade()
 393                                {
 394                                    this.report_changes_for_buffer(&buffer, &project, cx);
 395                                }
 396                            }
 397                        }),
 398                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 399                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 400                            else {
 401                                return;
 402                            };
 403                            zeta_project.registered_buffers.remove(&buffer_id);
 404                        }),
 405                    ],
 406                })
 407            }
 408        }
 409    }
 410
 411    fn report_changes_for_buffer(
 412        &mut self,
 413        buffer: &Entity<Buffer>,
 414        project: &Entity<Project>,
 415        cx: &mut Context<Self>,
 416    ) -> BufferSnapshot {
 417        let zeta_project = self.get_or_init_zeta_project(project, cx);
 418        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 419
 420        let new_snapshot = buffer.read(cx).snapshot();
 421        if new_snapshot.version != registered_buffer.snapshot.version {
 422            let old_snapshot =
 423                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 424            Self::push_event(
 425                zeta_project,
 426                Event::BufferChange {
 427                    old_snapshot,
 428                    new_snapshot: new_snapshot.clone(),
 429                    timestamp: Instant::now(),
 430                },
 431            );
 432        }
 433
 434        new_snapshot
 435    }
 436
 437    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 438        let events = &mut zeta_project.events;
 439
 440        if let Some(Event::BufferChange {
 441            new_snapshot: last_new_snapshot,
 442            timestamp: last_timestamp,
 443            ..
 444        }) = events.back_mut()
 445        {
 446            // Coalesce edits for the same buffer when they happen one after the other.
 447            let Event::BufferChange {
 448                old_snapshot,
 449                new_snapshot,
 450                timestamp,
 451            } = &event;
 452
 453            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 454                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 455                && old_snapshot.version == last_new_snapshot.version
 456            {
 457                *last_new_snapshot = new_snapshot.clone();
 458                *last_timestamp = *timestamp;
 459                return;
 460            }
 461        }
 462
 463        if events.len() >= MAX_EVENT_COUNT {
 464            // These are halved instead of popping to improve prompt caching.
 465            events.drain(..MAX_EVENT_COUNT / 2);
 466        }
 467
 468        events.push_back(event);
 469    }
 470
 471    fn current_prediction_for_buffer(
 472        &self,
 473        buffer: &Entity<Buffer>,
 474        project: &Entity<Project>,
 475        cx: &App,
 476    ) -> Option<BufferEditPrediction<'_>> {
 477        let project_state = self.projects.get(&project.entity_id())?;
 478
 479        let CurrentEditPrediction {
 480            requested_by_buffer_id,
 481            prediction,
 482        } = project_state.current_prediction.as_ref()?;
 483
 484        if prediction.targets_buffer(buffer.read(cx), cx) {
 485            Some(BufferEditPrediction::Local { prediction })
 486        } else if *requested_by_buffer_id == buffer.entity_id() {
 487            Some(BufferEditPrediction::Jump { prediction })
 488        } else {
 489            None
 490        }
 491    }
 492
    /// Accepts the project's current prediction and reports the acceptance to the server.
    ///
    /// The prediction is taken out of the project state before the request is sent,
    /// so it is cleared even if reporting fails. The report is posted from a
    /// detached task; errors are logged rather than returned.
    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
            return;
        };

        let Some(prediction) = project_state.current_prediction.take() else {
            return;
        };
        // Converted into whatever id type `AcceptEditPredictionBody` expects.
        let request_id = prediction.prediction.id.into();

        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        cx.spawn(async move |this, cx| {
            // `ZED_ACCEPT_PREDICTION_URL` overrides the default `/predict_edits/accept` endpoint.
            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
                http_client::Url::parse(&predict_edits_url)?
            } else {
                client
                    .http_client()
                    .build_zed_llm_url("/predict_edits/accept", &[])?
            };

            let response = cx
                .background_spawn(Self::send_api_request::<()>(
                    move |builder| {
                        let req = builder.uri(url.as_ref()).body(
                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
                        );
                        Ok(req?)
                    },
                    client,
                    llm_token,
                    app_version,
                ))
                .await;

            // NOTE(review): `handle_api_response` is defined outside this view; presumably it
            // inspects response headers (e.g. required version) before surfacing errors.
            Self::handle_api_response(&this, response, cx)?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
 534
 535    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 536        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 537            project_state.current_prediction.take();
 538        };
 539    }
 540
 541    pub fn refresh_prediction(
 542        &mut self,
 543        project: &Entity<Project>,
 544        buffer: &Entity<Buffer>,
 545        position: language::Anchor,
 546        cx: &mut Context<Self>,
 547    ) -> Task<Result<()>> {
 548        let request_task = self.request_prediction(project, buffer, position, cx);
 549        let buffer = buffer.clone();
 550        let project = project.clone();
 551
 552        cx.spawn(async move |this, cx| {
 553            if let Some(prediction) = request_task.await? {
 554                this.update(cx, |this, cx| {
 555                    let project_state = this
 556                        .projects
 557                        .get_mut(&project.entity_id())
 558                        .context("Project not found")?;
 559
 560                    let new_prediction = CurrentEditPrediction {
 561                        requested_by_buffer_id: buffer.entity_id(),
 562                        prediction: prediction,
 563                    };
 564
 565                    if project_state
 566                        .current_prediction
 567                        .as_ref()
 568                        .is_none_or(|old_prediction| {
 569                            new_prediction.should_replace_prediction(&old_prediction, cx)
 570                        })
 571                    {
 572                        project_state.current_prediction = Some(new_prediction);
 573                    }
 574                    anyhow::Ok(())
 575                })??;
 576            }
 577            Ok(())
 578        })
 579    }
 580
 581    fn request_prediction(
 582        &mut self,
 583        project: &Entity<Project>,
 584        buffer: &Entity<Buffer>,
 585        position: language::Anchor,
 586        cx: &mut Context<Self>,
 587    ) -> Task<Result<Option<EditPrediction>>> {
 588        let project_state = self.projects.get(&project.entity_id());
 589
 590        let index_state = project_state.map(|state| {
 591            state
 592                .syntax_index
 593                .read_with(cx, |index, _cx| index.state().clone())
 594        });
 595        let options = self.options.clone();
 596        let snapshot = buffer.read(cx).snapshot();
 597        let Some(excerpt_path) = snapshot
 598            .file()
 599            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 600        else {
 601            return Task::ready(Err(anyhow!("No file path for excerpt")));
 602        };
 603        let client = self.client.clone();
 604        let llm_token = self.llm_token.clone();
 605        let app_version = AppVersion::global(cx);
 606        let worktree_snapshots = project
 607            .read(cx)
 608            .worktrees(cx)
 609            .map(|worktree| worktree.read(cx).snapshot())
 610            .collect::<Vec<_>>();
 611        let debug_tx = self.debug_tx.clone();
 612
 613        let events = project_state
 614            .map(|state| {
 615                state
 616                    .events
 617                    .iter()
 618                    .filter_map(|event| event.to_request_event(cx))
 619                    .collect::<Vec<_>>()
 620            })
 621            .unwrap_or_default();
 622
 623        let diagnostics = snapshot.diagnostic_sets().clone();
 624
 625        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 626            let mut path = f.worktree.read(cx).absolutize(&f.path);
 627            if path.pop() { Some(path) } else { None }
 628        });
 629
 630        // TODO data collection
 631        let can_collect_data = cx.is_staff();
 632
 633        let mut included_files = project_state
 634            .and_then(|project_state| project_state.context.as_ref())
 635            .unwrap_or(&HashMap::default())
 636            .iter()
 637            .filter_map(|(buffer, ranges)| {
 638                let buffer = buffer.read(cx);
 639                Some((
 640                    buffer.snapshot(),
 641                    buffer.file()?.full_path(cx).into(),
 642                    ranges.clone(),
 643                ))
 644            })
 645            .collect::<Vec<_>>();
 646
 647        let request_task = cx.background_spawn({
 648            let snapshot = snapshot.clone();
 649            let buffer = buffer.clone();
 650            async move {
 651                let index_state = if let Some(index_state) = index_state {
 652                    Some(index_state.lock_owned().await)
 653                } else {
 654                    None
 655                };
 656
 657                let cursor_offset = position.to_offset(&snapshot);
 658                let cursor_point = cursor_offset.to_point(&snapshot);
 659
 660                let before_retrieval = chrono::Utc::now();
 661
 662                let (diagnostic_groups, diagnostic_groups_truncated) =
 663                    Self::gather_nearby_diagnostics(
 664                        cursor_offset,
 665                        &diagnostics,
 666                        &snapshot,
 667                        options.max_diagnostic_bytes,
 668                    );
 669
 670                let request = match options.context {
 671                    ContextMode::Llm(context_options) => {
 672                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 673                            cursor_point,
 674                            &snapshot,
 675                            &context_options.excerpt,
 676                            index_state.as_deref(),
 677                        ) else {
 678                            return Ok((None, None));
 679                        };
 680
 681                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
 682                            ..snapshot.anchor_before(excerpt.range.end);
 683
 684                        if let Some(buffer_ix) = included_files
 685                            .iter()
 686                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
 687                        {
 688                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
 689                            let range_ix = ranges
 690                                .binary_search_by(|probe| {
 691                                    probe
 692                                        .start
 693                                        .cmp(&excerpt_anchor_range.start, buffer)
 694                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 695                                })
 696                                .unwrap_or_else(|ix| ix);
 697
 698                            ranges.insert(range_ix, excerpt_anchor_range);
 699                            let last_ix = included_files.len() - 1;
 700                            included_files.swap(buffer_ix, last_ix);
 701                        } else {
 702                            included_files.push((
 703                                snapshot,
 704                                excerpt_path.clone(),
 705                                vec![excerpt_anchor_range],
 706                            ));
 707                        }
 708
 709                        let included_files = included_files
 710                            .into_iter()
 711                            .map(|(buffer, path, ranges)| {
 712                                let excerpts = merge_excerpts(
 713                                    &buffer,
 714                                    ranges.iter().map(|range| {
 715                                        let point_range = range.to_point(&buffer);
 716                                        Line(point_range.start.row)..Line(point_range.end.row)
 717                                    }),
 718                                );
 719                                predict_edits_v3::IncludedFile {
 720                                    path,
 721                                    max_row: Line(buffer.max_point().row),
 722                                    excerpts,
 723                                }
 724                            })
 725                            .collect::<Vec<_>>();
 726
 727                        predict_edits_v3::PredictEditsRequest {
 728                            excerpt_path,
 729                            excerpt: String::new(),
 730                            excerpt_line_range: Line(0)..Line(0),
 731                            excerpt_range: 0..0,
 732                            cursor_point: predict_edits_v3::Point {
 733                                line: predict_edits_v3::Line(cursor_point.row),
 734                                column: cursor_point.column,
 735                            },
 736                            included_files,
 737                            referenced_declarations: vec![],
 738                            events,
 739                            can_collect_data,
 740                            diagnostic_groups,
 741                            diagnostic_groups_truncated,
 742                            debug_info: debug_tx.is_some(),
 743                            prompt_max_bytes: Some(options.max_prompt_bytes),
 744                            prompt_format: options.prompt_format,
 745                            // TODO [zeta2]
 746                            signatures: vec![],
 747                            excerpt_parent: None,
 748                            git_info: None,
 749                        }
 750                    }
 751                    ContextMode::Syntax(context_options) => {
 752                        let Some(context) = EditPredictionContext::gather_context(
 753                            cursor_point,
 754                            &snapshot,
 755                            parent_abs_path.as_deref(),
 756                            &context_options,
 757                            index_state.as_deref(),
 758                        ) else {
 759                            return Ok((None, None));
 760                        };
 761
 762                        make_syntax_context_cloud_request(
 763                            excerpt_path,
 764                            context,
 765                            events,
 766                            can_collect_data,
 767                            diagnostic_groups,
 768                            diagnostic_groups_truncated,
 769                            None,
 770                            debug_tx.is_some(),
 771                            &worktree_snapshots,
 772                            index_state.as_deref(),
 773                            Some(options.max_prompt_bytes),
 774                            options.prompt_format,
 775                        )
 776                    }
 777                };
 778
 779                let retrieval_time = chrono::Utc::now() - before_retrieval;
 780
 781                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 782                    let (response_tx, response_rx) = oneshot::channel();
 783
 784                    if !request.referenced_declarations.is_empty() || !request.signatures.is_empty()
 785                    {
 786                    } else {
 787                    };
 788
 789                    let local_prompt = build_prompt(&request)
 790                        .map(|(prompt, _)| prompt)
 791                        .map_err(|err| err.to_string());
 792
 793                    debug_tx
 794                        .unbounded_send(PredictionDebugInfo {
 795                            request: request.clone(),
 796                            retrieval_time,
 797                            buffer: buffer.downgrade(),
 798                            local_prompt,
 799                            position,
 800                            response_rx,
 801                        })
 802                        .ok();
 803                    Some(response_tx)
 804                } else {
 805                    None
 806                };
 807
 808                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 809                    if let Some(debug_response_tx) = debug_response_tx {
 810                        debug_response_tx
 811                            .send(Err("Request skipped".to_string()))
 812                            .ok();
 813                    }
 814                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 815                }
 816
 817                let response =
 818                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 819
 820                if let Some(debug_response_tx) = debug_response_tx {
 821                    debug_response_tx
 822                        .send(
 823                            response
 824                                .as_ref()
 825                                .map_err(|err| err.to_string())
 826                                .map(|response| response.0.clone()),
 827                        )
 828                        .ok();
 829                }
 830
 831                response.map(|(res, usage)| (Some(res), usage))
 832            }
 833        });
 834
 835        let buffer = buffer.clone();
 836
 837        cx.spawn({
 838            let project = project.clone();
 839            async move |this, cx| {
 840                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 841                else {
 842                    return Ok(None);
 843                };
 844
 845                // TODO telemetry: duration, etc
 846                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 847            }
 848        })
 849    }
 850
 851    async fn send_prediction_request(
 852        client: Arc<Client>,
 853        llm_token: LlmApiToken,
 854        app_version: SemanticVersion,
 855        request: predict_edits_v3::PredictEditsRequest,
 856    ) -> Result<(
 857        predict_edits_v3::PredictEditsResponse,
 858        Option<EditPredictionUsage>,
 859    )> {
 860        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 861            http_client::Url::parse(&predict_edits_url)?
 862        } else {
 863            client
 864                .http_client()
 865                .build_zed_llm_url("/predict_edits/v3", &[])?
 866        };
 867
 868        Self::send_api_request(
 869            |builder| {
 870                let req = builder
 871                    .uri(url.as_ref())
 872                    .body(serde_json::to_string(&request)?.into());
 873                Ok(req?)
 874            },
 875            client,
 876            llm_token,
 877            app_version,
 878        )
 879        .await
 880    }
 881
 882    fn handle_api_response<T>(
 883        this: &WeakEntity<Self>,
 884        response: Result<(T, Option<EditPredictionUsage>)>,
 885        cx: &mut gpui::AsyncApp,
 886    ) -> Result<T> {
 887        match response {
 888            Ok((data, usage)) => {
 889                if let Some(usage) = usage {
 890                    this.update(cx, |this, cx| {
 891                        this.user_store.update(cx, |user_store, cx| {
 892                            user_store.update_edit_prediction_usage(usage, cx);
 893                        });
 894                    })
 895                    .ok();
 896                }
 897                Ok(data)
 898            }
 899            Err(err) => {
 900                if err.is::<ZedUpdateRequiredError>() {
 901                    cx.update(|cx| {
 902                        this.update(cx, |this, _cx| {
 903                            this.update_required = true;
 904                        })
 905                        .ok();
 906
 907                        let error_message: SharedString = err.to_string().into();
 908                        show_app_notification(
 909                            NotificationId::unique::<ZedUpdateRequiredError>(),
 910                            cx,
 911                            move |cx| {
 912                                cx.new(|cx| {
 913                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 914                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 915                                })
 916                            },
 917                        );
 918                    })
 919                    .ok();
 920                }
 921                Err(err)
 922            }
 923        }
 924    }
 925
 926    async fn send_api_request<Res>(
 927        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 928        client: Arc<Client>,
 929        llm_token: LlmApiToken,
 930        app_version: SemanticVersion,
 931    ) -> Result<(Res, Option<EditPredictionUsage>)>
 932    where
 933        Res: DeserializeOwned,
 934    {
 935        let http_client = client.http_client();
 936        let mut token = llm_token.acquire(&client).await?;
 937        let mut did_retry = false;
 938
 939        loop {
 940            let request_builder = http_client::Request::builder().method(Method::POST);
 941
 942            let request = build(
 943                request_builder
 944                    .header("Content-Type", "application/json")
 945                    .header("Authorization", format!("Bearer {}", token))
 946                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 947            )?;
 948
 949            let mut response = http_client.send(request).await?;
 950
 951            if let Some(minimum_required_version) = response
 952                .headers()
 953                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 954                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 955            {
 956                anyhow::ensure!(
 957                    app_version >= minimum_required_version,
 958                    ZedUpdateRequiredError {
 959                        minimum_version: minimum_required_version
 960                    }
 961                );
 962            }
 963
 964            if response.status().is_success() {
 965                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 966
 967                let mut body = Vec::new();
 968                response.body_mut().read_to_end(&mut body).await?;
 969                return Ok((serde_json::from_slice(&body)?, usage));
 970            } else if !did_retry
 971                && response
 972                    .headers()
 973                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 974                    .is_some()
 975            {
 976                did_retry = true;
 977                token = llm_token.refresh(&client).await?;
 978            } else {
 979                let mut body = String::new();
 980                response.body_mut().read_to_string(&mut body).await?;
 981                anyhow::bail!(
 982                    "Request failed with status: {:?}\nBody: {}",
 983                    response.status(),
 984                    body
 985                );
 986            }
 987        }
 988    }
 989
    /// If `refresh_context_if_needed` has not been called for a project for longer than this,
    /// the next call is treated as the start of a new editing session and the context is
    /// refetched immediately instead of being debounced.
    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
    /// Delay applied by `refresh_context_if_needed` before refetching context when the
    /// previous call was recent (not idle); each call replaces the pending delayed task.
    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
 992
 993    // Refresh the related excerpts when the user just beguns editing after
 994    // an idle period, and after they pause editing.
 995    fn refresh_context_if_needed(
 996        &mut self,
 997        project: &Entity<Project>,
 998        buffer: &Entity<language::Buffer>,
 999        cursor_position: language::Anchor,
1000        cx: &mut Context<Self>,
1001    ) {
1002        if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1003            return;
1004        }
1005
1006        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1007            return;
1008        };
1009
1010        let now = Instant::now();
1011        let was_idle = zeta_project
1012            .refresh_context_timestamp
1013            .map_or(true, |timestamp| {
1014                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1015            });
1016        zeta_project.refresh_context_timestamp = Some(now);
1017        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1018            let buffer = buffer.clone();
1019            let project = project.clone();
1020            async move |this, cx| {
1021                if was_idle {
1022                    log::debug!("refetching edit prediction context after idle");
1023                } else {
1024                    cx.background_executor()
1025                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1026                        .await;
1027                    log::debug!("refetching edit prediction context after pause");
1028                }
1029                this.update(cx, |this, cx| {
1030                    this.refresh_context(project, buffer, cursor_position, cx);
1031                })
1032                .ok()
1033            }
1034        }));
1035    }
1036
1037    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1038    // and avoid spawning more than one concurrent task.
1039    fn refresh_context(
1040        &mut self,
1041        project: Entity<Project>,
1042        buffer: Entity<language::Buffer>,
1043        cursor_position: language::Anchor,
1044        cx: &mut Context<Self>,
1045    ) {
1046        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1047            return;
1048        };
1049
1050        zeta_project
1051            .refresh_context_task
1052            .get_or_insert(cx.spawn(async move |this, cx| {
1053                let related_excerpts = this
1054                    .update(cx, |this, cx| {
1055                        let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1056                            return Task::ready(anyhow::Ok(HashMap::default()));
1057                        };
1058
1059                        let ContextMode::Llm(options) = &this.options().context else {
1060                            return Task::ready(anyhow::Ok(HashMap::default()));
1061                        };
1062
1063                        find_related_excerpts(
1064                            buffer.clone(),
1065                            cursor_position,
1066                            &project,
1067                            zeta_project.events.iter(),
1068                            options,
1069                            cx,
1070                        )
1071                    })
1072                    .ok()?
1073                    .await
1074                    .log_err()
1075                    .unwrap_or_default();
1076                this.update(cx, |this, _cx| {
1077                    let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1078                        return;
1079                    };
1080                    zeta_project.context = Some(related_excerpts);
1081                    zeta_project.refresh_context_task.take();
1082                })
1083                .ok()
1084            }));
1085    }
1086
1087    fn gather_nearby_diagnostics(
1088        cursor_offset: usize,
1089        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1090        snapshot: &BufferSnapshot,
1091        max_diagnostics_bytes: usize,
1092    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1093        // TODO: Could make this more efficient
1094        let mut diagnostic_groups = Vec::new();
1095        for (language_server_id, diagnostics) in diagnostic_sets {
1096            let mut groups = Vec::new();
1097            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1098            diagnostic_groups.extend(
1099                groups
1100                    .into_iter()
1101                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1102            );
1103        }
1104
1105        // sort by proximity to cursor
1106        diagnostic_groups.sort_by_key(|group| {
1107            let range = &group.entries[group.primary_ix].range;
1108            if range.start >= cursor_offset {
1109                range.start - cursor_offset
1110            } else if cursor_offset >= range.end {
1111                cursor_offset - range.end
1112            } else {
1113                (cursor_offset - range.start).min(range.end - cursor_offset)
1114            }
1115        });
1116
1117        let mut results = Vec::new();
1118        let mut diagnostic_groups_truncated = false;
1119        let mut diagnostics_byte_count = 0;
1120        for group in diagnostic_groups {
1121            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1122            diagnostics_byte_count += raw_value.get().len();
1123            if diagnostics_byte_count > max_diagnostics_bytes {
1124                diagnostic_groups_truncated = true;
1125                break;
1126            }
1127            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1128        }
1129
1130        (results, diagnostic_groups_truncated)
1131    }
1132
1133    // TODO: Dedupe with similar code in request_prediction?
1134    pub fn cloud_request_for_zeta_cli(
1135        &mut self,
1136        project: &Entity<Project>,
1137        buffer: &Entity<Buffer>,
1138        position: language::Anchor,
1139        cx: &mut Context<Self>,
1140    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1141        let project_state = self.projects.get(&project.entity_id());
1142
1143        let index_state = project_state.map(|state| {
1144            state
1145                .syntax_index
1146                .read_with(cx, |index, _cx| index.state().clone())
1147        });
1148        let options = self.options.clone();
1149        let snapshot = buffer.read(cx).snapshot();
1150        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1151            return Task::ready(Err(anyhow!("No file path for excerpt")));
1152        };
1153        let worktree_snapshots = project
1154            .read(cx)
1155            .worktrees(cx)
1156            .map(|worktree| worktree.read(cx).snapshot())
1157            .collect::<Vec<_>>();
1158
1159        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1160            let mut path = f.worktree.read(cx).absolutize(&f.path);
1161            if path.pop() { Some(path) } else { None }
1162        });
1163
1164        cx.background_spawn(async move {
1165            let index_state = if let Some(index_state) = index_state {
1166                Some(index_state.lock_owned().await)
1167            } else {
1168                None
1169            };
1170
1171            let cursor_point = position.to_point(&snapshot);
1172
1173            let debug_info = true;
1174            EditPredictionContext::gather_context(
1175                cursor_point,
1176                &snapshot,
1177                parent_abs_path.as_deref(),
1178                match &options.context {
1179                    ContextMode::Llm(_) => {
1180                        // TODO
1181                        panic!("Llm mode not supported in zeta cli yet");
1182                    }
1183                    ContextMode::Syntax(edit_prediction_context_options) => {
1184                        edit_prediction_context_options
1185                    }
1186                },
1187                index_state.as_deref(),
1188            )
1189            .context("Failed to select excerpt")
1190            .map(|context| {
1191                make_syntax_context_cloud_request(
1192                    excerpt_path.into(),
1193                    context,
1194                    // TODO pass everything
1195                    Vec::new(),
1196                    false,
1197                    Vec::new(),
1198                    false,
1199                    None,
1200                    debug_info,
1201                    &worktree_snapshots,
1202                    index_state.as_deref(),
1203                    Some(options.max_prompt_bytes),
1204                    options.prompt_format,
1205                )
1206            })
1207        })
1208    }
1209
1210    pub fn wait_for_initial_indexing(
1211        &mut self,
1212        project: &Entity<Project>,
1213        cx: &mut App,
1214    ) -> Task<Result<()>> {
1215        let zeta_project = self.get_or_init_zeta_project(project, cx);
1216        zeta_project
1217            .syntax_index
1218            .read(cx)
1219            .wait_for_initial_file_indexing(cx)
1220    }
1221}
1222
/// Error produced by `send_api_request` when the server's minimum-required-version
/// response header names a Zed version newer than the running one.
///
/// `handle_api_response` checks for this error type to set `update_required` and show an
/// "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: SemanticVersion,
}
1230
1231fn make_syntax_context_cloud_request(
1232    excerpt_path: Arc<Path>,
1233    context: EditPredictionContext,
1234    events: Vec<predict_edits_v3::Event>,
1235    can_collect_data: bool,
1236    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1237    diagnostic_groups_truncated: bool,
1238    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1239    debug_info: bool,
1240    worktrees: &Vec<worktree::Snapshot>,
1241    index_state: Option<&SyntaxIndexState>,
1242    prompt_max_bytes: Option<usize>,
1243    prompt_format: PromptFormat,
1244) -> predict_edits_v3::PredictEditsRequest {
1245    let mut signatures = Vec::new();
1246    let mut declaration_to_signature_index = HashMap::default();
1247    let mut referenced_declarations = Vec::new();
1248
1249    for snippet in context.declarations {
1250        let project_entry_id = snippet.declaration.project_entry_id();
1251        let Some(path) = worktrees.iter().find_map(|worktree| {
1252            worktree.entry_for_id(project_entry_id).map(|entry| {
1253                let mut full_path = RelPathBuf::new();
1254                full_path.push(worktree.root_name());
1255                full_path.push(&entry.path);
1256                full_path
1257            })
1258        }) else {
1259            continue;
1260        };
1261
1262        let parent_index = index_state.and_then(|index_state| {
1263            snippet.declaration.parent().and_then(|parent| {
1264                add_signature(
1265                    parent,
1266                    &mut declaration_to_signature_index,
1267                    &mut signatures,
1268                    index_state,
1269                )
1270            })
1271        });
1272
1273        let (text, text_is_truncated) = snippet.declaration.item_text();
1274        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1275            path: path.as_std_path().into(),
1276            text: text.into(),
1277            range: snippet.declaration.item_line_range(),
1278            text_is_truncated,
1279            signature_range: snippet.declaration.signature_range_in_item_text(),
1280            parent_index,
1281            signature_score: snippet.score(DeclarationStyle::Signature),
1282            declaration_score: snippet.score(DeclarationStyle::Declaration),
1283            score_components: snippet.components,
1284        });
1285    }
1286
1287    let excerpt_parent = index_state.and_then(|index_state| {
1288        context
1289            .excerpt
1290            .parent_declarations
1291            .last()
1292            .and_then(|(parent, _)| {
1293                add_signature(
1294                    *parent,
1295                    &mut declaration_to_signature_index,
1296                    &mut signatures,
1297                    index_state,
1298                )
1299            })
1300    });
1301
1302    predict_edits_v3::PredictEditsRequest {
1303        excerpt_path,
1304        excerpt: context.excerpt_text.body,
1305        excerpt_line_range: context.excerpt.line_range,
1306        excerpt_range: context.excerpt.range,
1307        cursor_point: predict_edits_v3::Point {
1308            line: predict_edits_v3::Line(context.cursor_point.row),
1309            column: context.cursor_point.column,
1310        },
1311        referenced_declarations,
1312        included_files: vec![],
1313        signatures,
1314        excerpt_parent,
1315        events,
1316        can_collect_data,
1317        diagnostic_groups,
1318        diagnostic_groups_truncated,
1319        git_info,
1320        debug_info,
1321        prompt_max_bytes,
1322        prompt_format,
1323    }
1324}
1325
1326fn add_signature(
1327    declaration_id: DeclarationId,
1328    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1329    signatures: &mut Vec<Signature>,
1330    index: &SyntaxIndexState,
1331) -> Option<usize> {
1332    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1333        return Some(*signature_index);
1334    }
1335    let Some(parent_declaration) = index.declaration(declaration_id) else {
1336        log::error!("bug: missing parent declaration");
1337        return None;
1338    };
1339    let parent_index = parent_declaration.parent().and_then(|parent| {
1340        add_signature(parent, declaration_to_signature_index, signatures, index)
1341    });
1342    let (text, text_is_truncated) = parent_declaration.signature_text();
1343    let signature_index = signatures.len();
1344    signatures.push(Signature {
1345        text: text.into(),
1346        text_is_truncated,
1347        parent_index,
1348        range: parent_declaration.signature_line_range(),
1349    });
1350    declaration_to_signature_index.insert(declaration_id, signature_index);
1351    Some(signature_index)
1352}
1353
1354#[cfg(test)]
1355mod tests {
1356    use std::{
1357        path::{Path, PathBuf},
1358        sync::Arc,
1359    };
1360
1361    use client::UserStore;
1362    use clock::FakeSystemClock;
1363    use cloud_llm_client::predict_edits_v3::{self, Point};
1364    use edit_prediction_context::Line;
1365    use futures::{
1366        AsyncReadExt, StreamExt,
1367        channel::{mpsc, oneshot},
1368    };
1369    use gpui::{
1370        Entity, TestAppContext,
1371        http_client::{FakeHttpClient, Response},
1372        prelude::*,
1373    };
1374    use indoc::indoc;
1375    use language::{LanguageServerId, OffsetRangeExt as _};
1376    use pretty_assertions::{assert_eq, assert_matches};
1377    use project::{FakeFs, Project};
1378    use serde_json::json;
1379    use settings::SettingsStore;
1380    use util::path;
1381    use uuid::Uuid;
1382
1383    use crate::{BufferEditPrediction, Zeta};
1384
1385    #[gpui::test]
1386    async fn test_current_state(cx: &mut TestAppContext) {
1387        let (zeta, mut req_rx) = init_test(cx);
1388        let fs = FakeFs::new(cx.executor());
1389        fs.insert_tree(
1390            "/root",
1391            json!({
1392                "1.txt": "Hello!\nHow\nBye",
1393                "2.txt": "Hola!\nComo\nAdios"
1394            }),
1395        )
1396        .await;
1397        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1398
1399        zeta.update(cx, |zeta, cx| {
1400            zeta.register_project(&project, cx);
1401        });
1402
1403        let buffer1 = project
1404            .update(cx, |project, cx| {
1405                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1406                project.open_buffer(path, cx)
1407            })
1408            .await
1409            .unwrap();
1410        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1411        let position = snapshot1.anchor_before(language::Point::new(1, 3));
1412
1413        // Prediction for current file
1414
1415        let prediction_task = zeta.update(cx, |zeta, cx| {
1416            zeta.refresh_prediction(&project, &buffer1, position, cx)
1417        });
1418        let (_request, respond_tx) = req_rx.next().await.unwrap();
1419        respond_tx
1420            .send(predict_edits_v3::PredictEditsResponse {
1421                request_id: Uuid::new_v4(),
1422                edits: vec![predict_edits_v3::Edit {
1423                    path: Path::new(path!("root/1.txt")).into(),
1424                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1425                    content: "Hello!\nHow are you?\nBye".into(),
1426                }],
1427                debug_info: None,
1428            })
1429            .unwrap();
1430        prediction_task.await.unwrap();
1431
1432        zeta.read_with(cx, |zeta, cx| {
1433            let prediction = zeta
1434                .current_prediction_for_buffer(&buffer1, &project, cx)
1435                .unwrap();
1436            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1437        });
1438
1439        // Prediction for another file
1440        let prediction_task = zeta.update(cx, |zeta, cx| {
1441            zeta.refresh_prediction(&project, &buffer1, position, cx)
1442        });
1443        let (_request, respond_tx) = req_rx.next().await.unwrap();
1444        respond_tx
1445            .send(predict_edits_v3::PredictEditsResponse {
1446                request_id: Uuid::new_v4(),
1447                edits: vec![predict_edits_v3::Edit {
1448                    path: Path::new(path!("root/2.txt")).into(),
1449                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1450                    content: "Hola!\nComo estas?\nAdios".into(),
1451                }],
1452                debug_info: None,
1453            })
1454            .unwrap();
1455        prediction_task.await.unwrap();
1456        zeta.read_with(cx, |zeta, cx| {
1457            let prediction = zeta
1458                .current_prediction_for_buffer(&buffer1, &project, cx)
1459                .unwrap();
1460            assert_matches!(
1461                prediction,
1462                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1463            );
1464        });
1465
1466        let buffer2 = project
1467            .update(cx, |project, cx| {
1468                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1469                project.open_buffer(path, cx)
1470            })
1471            .await
1472            .unwrap();
1473
1474        zeta.read_with(cx, |zeta, cx| {
1475            let prediction = zeta
1476                .current_prediction_for_buffer(&buffer2, &project, cx)
1477                .unwrap();
1478            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1479        });
1480    }
1481
1482    #[gpui::test]
1483    async fn test_simple_request(cx: &mut TestAppContext) {
1484        let (zeta, mut req_rx) = init_test(cx);
1485        let fs = FakeFs::new(cx.executor());
1486        fs.insert_tree(
1487            "/root",
1488            json!({
1489                "foo.md":  "Hello!\nHow\nBye"
1490            }),
1491        )
1492        .await;
1493        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1494
1495        let buffer = project
1496            .update(cx, |project, cx| {
1497                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1498                project.open_buffer(path, cx)
1499            })
1500            .await
1501            .unwrap();
1502        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1503        let position = snapshot.anchor_before(language::Point::new(1, 3));
1504
1505        let prediction_task = zeta.update(cx, |zeta, cx| {
1506            zeta.request_prediction(&project, &buffer, position, cx)
1507        });
1508
1509        let (request, respond_tx) = req_rx.next().await.unwrap();
1510        assert_eq!(
1511            request.excerpt_path.as_ref(),
1512            Path::new(path!("root/foo.md"))
1513        );
1514        assert_eq!(
1515            request.cursor_point,
1516            Point {
1517                line: Line(1),
1518                column: 3
1519            }
1520        );
1521
1522        respond_tx
1523            .send(predict_edits_v3::PredictEditsResponse {
1524                request_id: Uuid::new_v4(),
1525                edits: vec![predict_edits_v3::Edit {
1526                    path: Path::new(path!("root/foo.md")).into(),
1527                    range: Line(0)..Line(snapshot.max_point().row + 1),
1528                    content: "Hello!\nHow are you?\nBye".into(),
1529                }],
1530                debug_info: None,
1531            })
1532            .unwrap();
1533
1534        let prediction = prediction_task.await.unwrap().unwrap();
1535
1536        assert_eq!(prediction.edits.len(), 1);
1537        assert_eq!(
1538            prediction.edits[0].0.to_point(&snapshot).start,
1539            language::Point::new(1, 3)
1540        );
1541        assert_eq!(prediction.edits[0].1, " are you?");
1542    }
1543
1544    #[gpui::test]
1545    async fn test_request_events(cx: &mut TestAppContext) {
1546        let (zeta, mut req_rx) = init_test(cx);
1547        let fs = FakeFs::new(cx.executor());
1548        fs.insert_tree(
1549            "/root",
1550            json!({
1551                "foo.md": "Hello!\n\nBye"
1552            }),
1553        )
1554        .await;
1555        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1556
1557        let buffer = project
1558            .update(cx, |project, cx| {
1559                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1560                project.open_buffer(path, cx)
1561            })
1562            .await
1563            .unwrap();
1564
1565        zeta.update(cx, |zeta, cx| {
1566            zeta.register_buffer(&buffer, &project, cx);
1567        });
1568
1569        buffer.update(cx, |buffer, cx| {
1570            buffer.edit(vec![(7..7, "How")], None, cx);
1571        });
1572
1573        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1574        let position = snapshot.anchor_before(language::Point::new(1, 3));
1575
1576        let prediction_task = zeta.update(cx, |zeta, cx| {
1577            zeta.request_prediction(&project, &buffer, position, cx)
1578        });
1579
1580        let (request, respond_tx) = req_rx.next().await.unwrap();
1581
1582        assert_eq!(request.events.len(), 1);
1583        assert_eq!(
1584            request.events[0],
1585            predict_edits_v3::Event::BufferChange {
1586                path: Some(PathBuf::from(path!("root/foo.md"))),
1587                old_path: None,
1588                diff: indoc! {"
1589                        @@ -1,3 +1,3 @@
1590                         Hello!
1591                        -
1592                        +How
1593                         Bye
1594                    "}
1595                .to_string(),
1596                predicted: false
1597            }
1598        );
1599
1600        respond_tx
1601            .send(predict_edits_v3::PredictEditsResponse {
1602                request_id: Uuid::new_v4(),
1603                edits: vec![predict_edits_v3::Edit {
1604                    path: Path::new(path!("root/foo.md")).into(),
1605                    range: Line(0)..Line(snapshot.max_point().row + 1),
1606                    content: "Hello!\nHow are you?\nBye".into(),
1607                }],
1608                debug_info: None,
1609            })
1610            .unwrap();
1611
1612        let prediction = prediction_task.await.unwrap().unwrap();
1613
1614        assert_eq!(prediction.edits.len(), 1);
1615        assert_eq!(
1616            prediction.edits[0].0.to_point(&snapshot).start,
1617            language::Point::new(1, 3)
1618        );
1619        assert_eq!(prediction.edits[0].1, " are you?");
1620    }
1621
1622    #[gpui::test]
1623    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1624        let (zeta, mut req_rx) = init_test(cx);
1625        let fs = FakeFs::new(cx.executor());
1626        fs.insert_tree(
1627            "/root",
1628            json!({
1629                "foo.md": "Hello!\nBye"
1630            }),
1631        )
1632        .await;
1633        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1634
1635        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1636        let diagnostic = lsp::Diagnostic {
1637            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1638            severity: Some(lsp::DiagnosticSeverity::ERROR),
1639            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1640            ..Default::default()
1641        };
1642
1643        project.update(cx, |project, cx| {
1644            project.lsp_store().update(cx, |lsp_store, cx| {
1645                // Create some diagnostics
1646                lsp_store
1647                    .update_diagnostics(
1648                        LanguageServerId(0),
1649                        lsp::PublishDiagnosticsParams {
1650                            uri: path_to_buffer_uri.clone(),
1651                            diagnostics: vec![diagnostic],
1652                            version: None,
1653                        },
1654                        None,
1655                        language::DiagnosticSourceKind::Pushed,
1656                        &[],
1657                        cx,
1658                    )
1659                    .unwrap();
1660            });
1661        });
1662
1663        let buffer = project
1664            .update(cx, |project, cx| {
1665                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1666                project.open_buffer(path, cx)
1667            })
1668            .await
1669            .unwrap();
1670
1671        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1672        let position = snapshot.anchor_before(language::Point::new(0, 0));
1673
1674        let _prediction_task = zeta.update(cx, |zeta, cx| {
1675            zeta.request_prediction(&project, &buffer, position, cx)
1676        });
1677
1678        let (request, _respond_tx) = req_rx.next().await.unwrap();
1679
1680        assert_eq!(request.diagnostic_groups.len(), 1);
1681        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1682            .unwrap();
1683        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1684        assert_eq!(
1685            value,
1686            json!({
1687                "entries": [{
1688                    "range": {
1689                        "start": 8,
1690                        "end": 10
1691                    },
1692                    "diagnostic": {
1693                        "source": null,
1694                        "code": null,
1695                        "code_description": null,
1696                        "severity": 1,
1697                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1698                        "markdown": null,
1699                        "group_id": 0,
1700                        "is_primary": true,
1701                        "is_disk_based": false,
1702                        "is_unnecessary": false,
1703                        "source_kind": "Pushed",
1704                        "data": null,
1705                        "underline": true
1706                    }
1707                }],
1708                "primary_ix": 0
1709            })
1710        );
1711    }
1712
1713    fn init_test(
1714        cx: &mut TestAppContext,
1715    ) -> (
1716        Entity<Zeta>,
1717        mpsc::UnboundedReceiver<(
1718            predict_edits_v3::PredictEditsRequest,
1719            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1720        )>,
1721    ) {
1722        cx.update(move |cx| {
1723            let settings_store = SettingsStore::test(cx);
1724            cx.set_global(settings_store);
1725            language::init(cx);
1726            Project::init_settings(cx);
1727
1728            let (req_tx, req_rx) = mpsc::unbounded();
1729
1730            let http_client = FakeHttpClient::create({
1731                move |req| {
1732                    let uri = req.uri().path().to_string();
1733                    let mut body = req.into_body();
1734                    let req_tx = req_tx.clone();
1735                    async move {
1736                        let resp = match uri.as_str() {
1737                            "/client/llm_tokens" => serde_json::to_string(&json!({
1738                                "token": "test"
1739                            }))
1740                            .unwrap(),
1741                            "/predict_edits/v3" => {
1742                                let mut buf = Vec::new();
1743                                body.read_to_end(&mut buf).await.ok();
1744                                let req = serde_json::from_slice(&buf).unwrap();
1745
1746                                let (res_tx, res_rx) = oneshot::channel();
1747                                req_tx.unbounded_send((req, res_tx)).unwrap();
1748                                serde_json::to_string(&res_rx.await?).unwrap()
1749                            }
1750                            _ => {
1751                                panic!("Unexpected path: {}", uri)
1752                            }
1753                        };
1754
1755                        Ok(Response::builder().body(resp.into()).unwrap())
1756                    }
1757                }
1758            });
1759
1760            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1761            client.cloud_client().set_credentials(1, "test".into());
1762
1763            language_model::init(client.clone(), cx);
1764
1765            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1766            let zeta = Zeta::global(&client, &user_store, cx);
1767
1768            (zeta, req_rx)
1769        })
1770    }
1771}