zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
  10use collections::HashMap;
  11use edit_prediction_context::{
  12    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  13    EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
  14    SyntaxIndex, SyntaxIndexState,
  15};
  16use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  17use futures::AsyncReadExt as _;
  18use futures::channel::{mpsc, oneshot};
  19use gpui::http_client::{AsyncBody, Method};
  20use gpui::{
  21    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  22    http_client, prelude::*,
  23};
  24use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language::{BufferSnapshot, OffsetRangeExt};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use serde::de::DeserializeOwned;
  30use std::collections::{VecDeque, hash_map};
  31use std::ops::Range;
  32use std::path::Path;
  33use std::str::FromStr as _;
  34use std::sync::Arc;
  35use std::time::{Duration, Instant};
  36use thiserror::Error;
  37use util::ResultExt as _;
  38use util::rel_path::RelPathBuf;
  39use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  40
  41mod merge_excerpts;
  42mod prediction;
  43mod provider;
  44mod related_excerpts;
  45
  46use crate::merge_excerpts::merge_excerpts;
  47use crate::prediction::EditPrediction;
  48pub use crate::related_excerpts::LlmContextOptions;
  49use crate::related_excerpts::find_related_excerpts;
  50pub use provider::ZetaEditPredictionProvider;
  51
/// Edits to the same buffer that follow each other within this interval are coalesced
/// into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When the limit is reached, the oldest `MAX_EVENT_COUNT / 2` events are dropped at once
/// (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;

/// Default options for selecting the excerpt around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default context strategy: `ContextMode::Llm` with [`DEFAULT_LLM_CONTEXT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);

/// Default options for `ContextMode::Llm`.
pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
};

/// Default options for `ContextMode::Syntax`.
// NOTE(review): `max_retrieved_declarations: 0` presumably disables declaration retrieval —
// confirm against `edit_prediction_context`.
pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
    EditPredictionContextOptions {
        use_imports: true,
        max_retrieved_declarations: 0,
        excerpt: DEFAULT_EXCERPT_OPTIONS,
        score: EditPredictionScoreOptions {
            omit_excerpt_overlaps: true,
        },
    };

/// Options a freshly created [`Zeta`] starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
  86
/// Feature flag gating Zeta2 edit predictions (flag name `"zeta2"`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Returning `false` opts this flag out of being on for staff; presumably it has to be
    // enabled explicitly.
    fn enabled_for_staff() -> bool {
        false
    }
}
  96
/// App-global slot holding the shared [`Zeta`] entity; installed by [`Zeta::global`] and
/// read by [`Zeta::try_global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
 101
/// Zeta2 edit prediction engine: records buffer edits per project and requests edit
/// predictions from the LLM service.
pub struct Zeta {
    /// Client used for the prediction HTTP requests.
    client: Arc<Client>,
    /// Source of edit prediction usage information (see [`Zeta::usage`]).
    user_store: Entity<UserStore>,
    /// Token sent with requests; refreshed whenever `RefreshLlmTokenListener` emits.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // NOTE(review): not written in the visible code; presumably set when the server reports
    // that a newer Zed version is required — confirm.
    update_required: bool,
    /// When set (via [`Zeta::debug_info`]), every prediction request emits a
    /// [`PredictionDebugInfo`] on this channel.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
 112
/// Tunable options used when building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// How context around the cursor is gathered.
    pub context: ContextMode,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte limit passed to `gather_nearby_diagnostics`.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project's index is created.
    pub file_indexing_parallelism: usize,
}

/// Strategy for gathering context for a prediction request.
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
    /// Excerpt around the cursor plus the project's stored context ranges.
    Llm(LlmContextOptions),
    /// Context gathered with `EditPredictionContext::gather_context` from the syntax index.
    Syntax(EditPredictionContextOptions),
}
 127
 128impl ContextMode {
 129    pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
 130        match self {
 131            ContextMode::Llm(options) => &options.excerpt,
 132            ContextMode::Syntax(options) => &options.excerpt,
 133        }
 134    }
 135}
 136
/// Debug information for one prediction request, sent on the channel returned by
/// [`Zeta::debug_info`].
pub struct PredictionDebugInfo {
    /// The request that is sent to the server.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering diagnostics/context and building `request`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt built locally from `request` with `build_prompt`, or the build error message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response, or an error message (also when the request was
    /// skipped).
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Alias for the protocol's `predict_edits_v3::DebugInfo`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 147
/// Prediction state kept for one project.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer edits, oldest first; at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The project's current prediction, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Extra ranges per buffer that `ContextMode::Llm` requests include as files.
    context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
    // NOTE(review): the `refresh_context_*` fields are unused in the visible code; presumably
    // they drive a debounced refresh of `context` — confirm.
    refresh_context_task: Option<Task<Option<()>>>,
    refresh_context_debounce_task: Option<Task<Option<()>>>,
    refresh_context_timestamp: Option<Instant>,
}

/// A prediction paired with the id of the buffer whose request produced it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 164
 165impl CurrentEditPrediction {
 166    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 167        let Some(new_edits) = self
 168            .prediction
 169            .interpolate(&self.prediction.buffer.read(cx))
 170        else {
 171            return false;
 172        };
 173
 174        if self.prediction.buffer != old_prediction.prediction.buffer {
 175            return true;
 176        }
 177
 178        let Some(old_edits) = old_prediction
 179            .prediction
 180            .interpolate(&old_prediction.prediction.buffer.read(cx))
 181        else {
 182            return true;
 183        };
 184
 185        // This reduces the occurrence of UI thrash from replacing edits
 186        //
 187        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 188        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 189            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 190            && old_edits.len() == 1
 191            && new_edits.len() == 1
 192        {
 193            let (old_range, old_text) = &old_edits[0];
 194            let (new_range, new_text) = &new_edits[0];
 195            new_range == old_range && new_text.starts_with(old_text)
 196        } else {
 197            true
 198        }
 199    }
 200}
 201
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction edits this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but edits another one.
    Jump { prediction: &'a EditPrediction },
}

/// A buffer whose edits are tracked for a project.
struct RegisteredBuffer {
    /// Snapshot as of the last reported change; the next edit is diffed against it.
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer for the buffer.
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// The buffer went from `old_snapshot` to `new_snapshot`; consecutive edits may be
    /// coalesced into one event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the most recent (possibly coalesced) change was recorded.
        timestamp: Instant,
    },
}
 222
 223impl Event {
 224    pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
 225        match self {
 226            Event::BufferChange {
 227                old_snapshot,
 228                new_snapshot,
 229                ..
 230            } => {
 231                let path = new_snapshot.file().map(|f| f.full_path(cx));
 232
 233                let old_path = old_snapshot.file().and_then(|f| {
 234                    let old_path = f.full_path(cx);
 235                    if Some(&old_path) != path.as_ref() {
 236                        Some(old_path)
 237                    } else {
 238                        None
 239                    }
 240                });
 241
 242                // TODO [zeta2] move to bg?
 243                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 244
 245                if path == old_path && diff.is_empty() {
 246                    None
 247                } else {
 248                    Some(predict_edits_v3::Event::BufferChange {
 249                        old_path,
 250                        path,
 251                        diff,
 252                        //todo: Actually detect if this edit was predicted or not
 253                        predicted: false,
 254                    })
 255                }
 256            }
 257        }
 258    }
 259}
 260
 261impl Zeta {
 262    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 263        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 264    }
 265
 266    pub fn global(
 267        client: &Arc<Client>,
 268        user_store: &Entity<UserStore>,
 269        cx: &mut App,
 270    ) -> Entity<Self> {
 271        cx.try_global::<ZetaGlobal>()
 272            .map(|global| global.0.clone())
 273            .unwrap_or_else(|| {
 274                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 275                cx.set_global(ZetaGlobal(zeta.clone()));
 276                zeta
 277            })
 278    }
 279
 280    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 281        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 282
 283        Self {
 284            projects: HashMap::default(),
 285            client,
 286            user_store,
 287            options: DEFAULT_OPTIONS,
 288            llm_token: LlmApiToken::default(),
 289            _llm_token_subscription: cx.subscribe(
 290                &refresh_llm_token_listener,
 291                |this, _listener, _event, cx| {
 292                    let client = this.client.clone();
 293                    let llm_token = this.llm_token.clone();
 294                    cx.spawn(async move |_this, _cx| {
 295                        llm_token.refresh(&client).await?;
 296                        anyhow::Ok(())
 297                    })
 298                    .detach_and_log_err(cx);
 299                },
 300            ),
 301            update_required: false,
 302            debug_tx: None,
 303        }
 304    }
 305
 306    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 307        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 308        self.debug_tx = Some(debug_watch_tx);
 309        debug_watch_rx
 310    }
 311
 312    pub fn options(&self) -> &ZetaOptions {
 313        &self.options
 314    }
 315
 316    pub fn set_options(&mut self, options: ZetaOptions) {
 317        self.options = options;
 318    }
 319
 320    pub fn clear_history(&mut self) {
 321        for zeta_project in self.projects.values_mut() {
 322            zeta_project.events.clear();
 323        }
 324    }
 325
 326    pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
 327        static EMPTY_EVENTS: VecDeque<Event> = VecDeque::new();
 328        self.projects
 329            .get(&project.entity_id())
 330            .map_or(&EMPTY_EVENTS, |project| &project.events)
 331            .iter()
 332    }
 333
 334    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 335        self.user_store.read(cx).edit_prediction_usage()
 336    }
 337
 338    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 339        self.get_or_init_zeta_project(project, cx);
 340    }
 341
 342    pub fn register_buffer(
 343        &mut self,
 344        buffer: &Entity<Buffer>,
 345        project: &Entity<Project>,
 346        cx: &mut Context<Self>,
 347    ) {
 348        let zeta_project = self.get_or_init_zeta_project(project, cx);
 349        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 350    }
 351
 352    fn get_or_init_zeta_project(
 353        &mut self,
 354        project: &Entity<Project>,
 355        cx: &mut App,
 356    ) -> &mut ZetaProject {
 357        self.projects
 358            .entry(project.entity_id())
 359            .or_insert_with(|| ZetaProject {
 360                syntax_index: cx.new(|cx| {
 361                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 362                }),
 363                events: VecDeque::new(),
 364                registered_buffers: HashMap::default(),
 365                current_prediction: None,
 366                context: None,
 367                refresh_context_task: None,
 368                refresh_context_debounce_task: None,
 369                refresh_context_timestamp: None,
 370            })
 371    }
 372
    /// Returns `buffer`'s registration in `zeta_project`, creating it on first use.
    ///
    /// A new registration stores the buffer's current snapshot and holds two subscriptions:
    /// one calls `report_changes_for_buffer` on every `BufferEvent::Edited` (while `project`
    /// can still be upgraded), the other removes the registration when the buffer entity is
    /// released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                // Baseline the first reported edit is diffed against.
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Only a weak handle to the project is captured.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            // Nothing to clean up if the project has no state.
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }
 410
 411    fn report_changes_for_buffer(
 412        &mut self,
 413        buffer: &Entity<Buffer>,
 414        project: &Entity<Project>,
 415        cx: &mut Context<Self>,
 416    ) -> BufferSnapshot {
 417        let zeta_project = self.get_or_init_zeta_project(project, cx);
 418        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 419
 420        let new_snapshot = buffer.read(cx).snapshot();
 421        if new_snapshot.version != registered_buffer.snapshot.version {
 422            let old_snapshot =
 423                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 424            Self::push_event(
 425                zeta_project,
 426                Event::BufferChange {
 427                    old_snapshot,
 428                    new_snapshot: new_snapshot.clone(),
 429                    timestamp: Instant::now(),
 430                },
 431            );
 432        }
 433
 434        new_snapshot
 435    }
 436
 437    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 438        let events = &mut zeta_project.events;
 439
 440        if let Some(Event::BufferChange {
 441            new_snapshot: last_new_snapshot,
 442            timestamp: last_timestamp,
 443            ..
 444        }) = events.back_mut()
 445        {
 446            // Coalesce edits for the same buffer when they happen one after the other.
 447            let Event::BufferChange {
 448                old_snapshot,
 449                new_snapshot,
 450                timestamp,
 451            } = &event;
 452
 453            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 454                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 455                && old_snapshot.version == last_new_snapshot.version
 456            {
 457                *last_new_snapshot = new_snapshot.clone();
 458                *last_timestamp = *timestamp;
 459                return;
 460            }
 461        }
 462
 463        if events.len() >= MAX_EVENT_COUNT {
 464            // These are halved instead of popping to improve prompt caching.
 465            events.drain(..MAX_EVENT_COUNT / 2);
 466        }
 467
 468        events.push_back(event);
 469    }
 470
 471    fn current_prediction_for_buffer(
 472        &self,
 473        buffer: &Entity<Buffer>,
 474        project: &Entity<Project>,
 475        cx: &App,
 476    ) -> Option<BufferEditPrediction<'_>> {
 477        let project_state = self.projects.get(&project.entity_id())?;
 478
 479        let CurrentEditPrediction {
 480            requested_by_buffer_id,
 481            prediction,
 482        } = project_state.current_prediction.as_ref()?;
 483
 484        if prediction.targets_buffer(buffer.read(cx), cx) {
 485            Some(BufferEditPrediction::Local { prediction })
 486        } else if *requested_by_buffer_id == buffer.entity_id() {
 487            Some(BufferEditPrediction::Jump { prediction })
 488        } else {
 489            None
 490        }
 491    }
 492
 493    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 494        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 495            return;
 496        };
 497
 498        let Some(prediction) = project_state.current_prediction.take() else {
 499            return;
 500        };
 501        let request_id = prediction.prediction.id.into();
 502
 503        let client = self.client.clone();
 504        let llm_token = self.llm_token.clone();
 505        let app_version = AppVersion::global(cx);
 506        cx.spawn(async move |this, cx| {
 507            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 508                http_client::Url::parse(&predict_edits_url)?
 509            } else {
 510                client
 511                    .http_client()
 512                    .build_zed_llm_url("/predict_edits/accept", &[])?
 513            };
 514
 515            let response = cx
 516                .background_spawn(Self::send_api_request::<()>(
 517                    move |builder| {
 518                        let req = builder.uri(url.as_ref()).body(
 519                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 520                        );
 521                        Ok(req?)
 522                    },
 523                    client,
 524                    llm_token,
 525                    app_version,
 526                ))
 527                .await;
 528
 529            Self::handle_api_response(&this, response, cx)?;
 530            anyhow::Ok(())
 531        })
 532        .detach_and_log_err(cx);
 533    }
 534
 535    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 536        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 537            project_state.current_prediction.take();
 538        };
 539    }
 540
 541    pub fn refresh_prediction(
 542        &mut self,
 543        project: &Entity<Project>,
 544        buffer: &Entity<Buffer>,
 545        position: language::Anchor,
 546        cx: &mut Context<Self>,
 547    ) -> Task<Result<()>> {
 548        let request_task = self.request_prediction(project, buffer, position, cx);
 549        let buffer = buffer.clone();
 550        let project = project.clone();
 551
 552        cx.spawn(async move |this, cx| {
 553            if let Some(prediction) = request_task.await? {
 554                this.update(cx, |this, cx| {
 555                    let project_state = this
 556                        .projects
 557                        .get_mut(&project.entity_id())
 558                        .context("Project not found")?;
 559
 560                    let new_prediction = CurrentEditPrediction {
 561                        requested_by_buffer_id: buffer.entity_id(),
 562                        prediction: prediction,
 563                    };
 564
 565                    if project_state
 566                        .current_prediction
 567                        .as_ref()
 568                        .is_none_or(|old_prediction| {
 569                            new_prediction.should_replace_prediction(&old_prediction, cx)
 570                        })
 571                    {
 572                        project_state.current_prediction = Some(new_prediction);
 573                    }
 574                    anyhow::Ok(())
 575                })??;
 576            }
 577            Ok(())
 578        })
 579    }
 580
 581    fn request_prediction(
 582        &mut self,
 583        project: &Entity<Project>,
 584        buffer: &Entity<Buffer>,
 585        position: language::Anchor,
 586        cx: &mut Context<Self>,
 587    ) -> Task<Result<Option<EditPrediction>>> {
 588        let project_state = self.projects.get(&project.entity_id());
 589
 590        let index_state = project_state.map(|state| {
 591            state
 592                .syntax_index
 593                .read_with(cx, |index, _cx| index.state().clone())
 594        });
 595        let options = self.options.clone();
 596        let snapshot = buffer.read(cx).snapshot();
 597        let Some(excerpt_path) = snapshot
 598            .file()
 599            .map(|path| -> Arc<Path> { path.full_path(cx).into() })
 600        else {
 601            return Task::ready(Err(anyhow!("No file path for excerpt")));
 602        };
 603        let client = self.client.clone();
 604        let llm_token = self.llm_token.clone();
 605        let app_version = AppVersion::global(cx);
 606        let worktree_snapshots = project
 607            .read(cx)
 608            .worktrees(cx)
 609            .map(|worktree| worktree.read(cx).snapshot())
 610            .collect::<Vec<_>>();
 611        let debug_tx = self.debug_tx.clone();
 612
 613        let events = project_state
 614            .map(|state| {
 615                state
 616                    .events
 617                    .iter()
 618                    .filter_map(|event| event.to_request_event(cx))
 619                    .collect::<Vec<_>>()
 620            })
 621            .unwrap_or_default();
 622
 623        let diagnostics = snapshot.diagnostic_sets().clone();
 624
 625        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 626            let mut path = f.worktree.read(cx).absolutize(&f.path);
 627            if path.pop() { Some(path) } else { None }
 628        });
 629
 630        // TODO data collection
 631        let can_collect_data = cx.is_staff();
 632
 633        let mut included_files = project_state
 634            .and_then(|project_state| project_state.context.as_ref())
 635            .unwrap_or(&HashMap::default())
 636            .iter()
 637            .filter_map(|(buffer, ranges)| {
 638                let buffer = buffer.read(cx);
 639                Some((
 640                    buffer.snapshot(),
 641                    buffer.file()?.full_path(cx).into(),
 642                    ranges.clone(),
 643                ))
 644            })
 645            .collect::<Vec<_>>();
 646
 647        let request_task = cx.background_spawn({
 648            let snapshot = snapshot.clone();
 649            let buffer = buffer.clone();
 650            async move {
 651                let index_state = if let Some(index_state) = index_state {
 652                    Some(index_state.lock_owned().await)
 653                } else {
 654                    None
 655                };
 656
 657                let cursor_offset = position.to_offset(&snapshot);
 658                let cursor_point = cursor_offset.to_point(&snapshot);
 659
 660                let before_retrieval = chrono::Utc::now();
 661
 662                let (diagnostic_groups, diagnostic_groups_truncated) =
 663                    Self::gather_nearby_diagnostics(
 664                        cursor_offset,
 665                        &diagnostics,
 666                        &snapshot,
 667                        options.max_diagnostic_bytes,
 668                    );
 669
 670                let request = match options.context {
 671                    ContextMode::Llm(context_options) => {
 672                        let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
 673                            cursor_point,
 674                            &snapshot,
 675                            &context_options.excerpt,
 676                            index_state.as_deref(),
 677                        ) else {
 678                            return Ok((None, None));
 679                        };
 680
 681                        let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
 682                            ..snapshot.anchor_before(excerpt.range.end);
 683
 684                        if let Some(buffer_ix) = included_files
 685                            .iter()
 686                            .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
 687                        {
 688                            let (buffer, _, ranges) = &mut included_files[buffer_ix];
 689                            let range_ix = ranges
 690                                .binary_search_by(|probe| {
 691                                    probe
 692                                        .start
 693                                        .cmp(&excerpt_anchor_range.start, buffer)
 694                                        .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
 695                                })
 696                                .unwrap_or_else(|ix| ix);
 697
 698                            ranges.insert(range_ix, excerpt_anchor_range);
 699                            let last_ix = included_files.len() - 1;
 700                            included_files.swap(buffer_ix, last_ix);
 701                        } else {
 702                            included_files.push((
 703                                snapshot,
 704                                excerpt_path.clone(),
 705                                vec![excerpt_anchor_range],
 706                            ));
 707                        }
 708
 709                        let included_files = included_files
 710                            .into_iter()
 711                            .map(|(buffer, path, ranges)| {
 712                                let excerpts = merge_excerpts(
 713                                    &buffer,
 714                                    ranges.iter().map(|range| {
 715                                        let point_range = range.to_point(&buffer);
 716                                        Line(point_range.start.row)..Line(point_range.end.row)
 717                                    }),
 718                                );
 719                                predict_edits_v3::IncludedFile {
 720                                    path,
 721                                    max_row: Line(buffer.max_point().row),
 722                                    excerpts,
 723                                }
 724                            })
 725                            .collect::<Vec<_>>();
 726
 727                        predict_edits_v3::PredictEditsRequest {
 728                            excerpt_path,
 729                            excerpt: String::new(),
 730                            excerpt_line_range: Line(0)..Line(0),
 731                            excerpt_range: 0..0,
 732                            cursor_point: predict_edits_v3::Point {
 733                                line: predict_edits_v3::Line(cursor_point.row),
 734                                column: cursor_point.column,
 735                            },
 736                            included_files,
 737                            referenced_declarations: vec![],
 738                            events,
 739                            can_collect_data,
 740                            diagnostic_groups,
 741                            diagnostic_groups_truncated,
 742                            debug_info: debug_tx.is_some(),
 743                            prompt_max_bytes: Some(options.max_prompt_bytes),
 744                            prompt_format: options.prompt_format,
 745                            // TODO [zeta2]
 746                            signatures: vec![],
 747                            excerpt_parent: None,
 748                            git_info: None,
 749                        }
 750                    }
 751                    ContextMode::Syntax(context_options) => {
 752                        let Some(context) = EditPredictionContext::gather_context(
 753                            cursor_point,
 754                            &snapshot,
 755                            parent_abs_path.as_deref(),
 756                            &context_options,
 757                            index_state.as_deref(),
 758                        ) else {
 759                            return Ok((None, None));
 760                        };
 761
 762                        make_syntax_context_cloud_request(
 763                            excerpt_path,
 764                            context,
 765                            events,
 766                            can_collect_data,
 767                            diagnostic_groups,
 768                            diagnostic_groups_truncated,
 769                            None,
 770                            debug_tx.is_some(),
 771                            &worktree_snapshots,
 772                            index_state.as_deref(),
 773                            Some(options.max_prompt_bytes),
 774                            options.prompt_format,
 775                        )
 776                    }
 777                };
 778
 779                let retrieval_time = chrono::Utc::now() - before_retrieval;
 780
 781                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 782                    let (response_tx, response_rx) = oneshot::channel();
 783
 784                    if !request.referenced_declarations.is_empty() || !request.signatures.is_empty()
 785                    {
 786                    } else {
 787                    };
 788
 789                    let local_prompt = build_prompt(&request)
 790                        .map(|(prompt, _)| prompt)
 791                        .map_err(|err| err.to_string());
 792
 793                    debug_tx
 794                        .unbounded_send(PredictionDebugInfo {
 795                            request: request.clone(),
 796                            retrieval_time,
 797                            buffer: buffer.downgrade(),
 798                            local_prompt,
 799                            position,
 800                            response_rx,
 801                        })
 802                        .ok();
 803                    Some(response_tx)
 804                } else {
 805                    None
 806                };
 807
 808                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 809                    if let Some(debug_response_tx) = debug_response_tx {
 810                        debug_response_tx
 811                            .send(Err("Request skipped".to_string()))
 812                            .ok();
 813                    }
 814                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 815                }
 816
 817                let response =
 818                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 819
 820                if let Some(debug_response_tx) = debug_response_tx {
 821                    debug_response_tx
 822                        .send(
 823                            response
 824                                .as_ref()
 825                                .map_err(|err| err.to_string())
 826                                .map(|response| response.0.clone()),
 827                        )
 828                        .ok();
 829                }
 830
 831                response.map(|(res, usage)| (Some(res), usage))
 832            }
 833        });
 834
 835        let buffer = buffer.clone();
 836
 837        cx.spawn({
 838            let project = project.clone();
 839            async move |this, cx| {
 840                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 841                else {
 842                    return Ok(None);
 843                };
 844
 845                // TODO telemetry: duration, etc
 846                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 847            }
 848        })
 849    }
 850
 851    async fn send_prediction_request(
 852        client: Arc<Client>,
 853        llm_token: LlmApiToken,
 854        app_version: SemanticVersion,
 855        request: predict_edits_v3::PredictEditsRequest,
 856    ) -> Result<(
 857        predict_edits_v3::PredictEditsResponse,
 858        Option<EditPredictionUsage>,
 859    )> {
 860        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 861            http_client::Url::parse(&predict_edits_url)?
 862        } else {
 863            client
 864                .http_client()
 865                .build_zed_llm_url("/predict_edits/v3", &[])?
 866        };
 867
 868        Self::send_api_request(
 869            |builder| {
 870                let req = builder
 871                    .uri(url.as_ref())
 872                    .body(serde_json::to_string(&request)?.into());
 873                Ok(req?)
 874            },
 875            client,
 876            llm_token,
 877            app_version,
 878        )
 879        .await
 880    }
 881
 882    fn handle_api_response<T>(
 883        this: &WeakEntity<Self>,
 884        response: Result<(T, Option<EditPredictionUsage>)>,
 885        cx: &mut gpui::AsyncApp,
 886    ) -> Result<T> {
 887        match response {
 888            Ok((data, usage)) => {
 889                if let Some(usage) = usage {
 890                    this.update(cx, |this, cx| {
 891                        this.user_store.update(cx, |user_store, cx| {
 892                            user_store.update_edit_prediction_usage(usage, cx);
 893                        });
 894                    })
 895                    .ok();
 896                }
 897                Ok(data)
 898            }
 899            Err(err) => {
 900                if err.is::<ZedUpdateRequiredError>() {
 901                    cx.update(|cx| {
 902                        this.update(cx, |this, _cx| {
 903                            this.update_required = true;
 904                        })
 905                        .ok();
 906
 907                        let error_message: SharedString = err.to_string().into();
 908                        show_app_notification(
 909                            NotificationId::unique::<ZedUpdateRequiredError>(),
 910                            cx,
 911                            move |cx| {
 912                                cx.new(|cx| {
 913                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 914                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 915                                })
 916                            },
 917                        );
 918                    })
 919                    .ok();
 920                }
 921                Err(err)
 922            }
 923        }
 924    }
 925
 926    async fn send_api_request<Res>(
 927        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 928        client: Arc<Client>,
 929        llm_token: LlmApiToken,
 930        app_version: SemanticVersion,
 931    ) -> Result<(Res, Option<EditPredictionUsage>)>
 932    where
 933        Res: DeserializeOwned,
 934    {
 935        let http_client = client.http_client();
 936        let mut token = llm_token.acquire(&client).await?;
 937        let mut did_retry = false;
 938
 939        loop {
 940            let request_builder = http_client::Request::builder().method(Method::POST);
 941
 942            let request = build(
 943                request_builder
 944                    .header("Content-Type", "application/json")
 945                    .header("Authorization", format!("Bearer {}", token))
 946                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 947            )?;
 948
 949            let mut response = http_client.send(request).await?;
 950
 951            if let Some(minimum_required_version) = response
 952                .headers()
 953                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 954                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 955            {
 956                anyhow::ensure!(
 957                    app_version >= minimum_required_version,
 958                    ZedUpdateRequiredError {
 959                        minimum_version: minimum_required_version
 960                    }
 961                );
 962            }
 963
 964            if response.status().is_success() {
 965                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 966
 967                let mut body = Vec::new();
 968                response.body_mut().read_to_end(&mut body).await?;
 969                return Ok((serde_json::from_slice(&body)?, usage));
 970            } else if !did_retry
 971                && response
 972                    .headers()
 973                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 974                    .is_some()
 975            {
 976                did_retry = true;
 977                token = llm_token.refresh(&client).await?;
 978            } else {
 979                let mut body = String::new();
 980                response.body_mut().read_to_string(&mut body).await?;
 981                anyhow::bail!(
 982                    "Request failed with status: {:?}\nBody: {}",
 983                    response.status(),
 984                    body
 985                );
 986            }
 987        }
 988    }
 989
 990    pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
 991    pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
 992
 993    // Refresh the related excerpts when the user just beguns editing after
 994    // an idle period, and after they pause editing.
 995    fn refresh_context_if_needed(
 996        &mut self,
 997        project: &Entity<Project>,
 998        buffer: &Entity<language::Buffer>,
 999        cursor_position: language::Anchor,
1000        cx: &mut Context<Self>,
1001    ) {
1002        if !matches!(&self.options().context, ContextMode::Llm { .. }) {
1003            return;
1004        }
1005
1006        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1007            return;
1008        };
1009
1010        let now = Instant::now();
1011        let was_idle = zeta_project
1012            .refresh_context_timestamp
1013            .map_or(true, |timestamp| {
1014                now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
1015            });
1016        zeta_project.refresh_context_timestamp = Some(now);
1017        zeta_project.refresh_context_debounce_task = Some(cx.spawn({
1018            let buffer = buffer.clone();
1019            let project = project.clone();
1020            async move |this, cx| {
1021                if was_idle {
1022                    log::debug!("refetching edit prediction context after idle");
1023                } else {
1024                    cx.background_executor()
1025                        .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
1026                        .await;
1027                    log::debug!("refetching edit prediction context after pause");
1028                }
1029                this.update(cx, |this, cx| {
1030                    this.refresh_context(project, buffer, cursor_position, cx);
1031                })
1032                .ok()
1033            }
1034        }));
1035    }
1036
1037    // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
1038    // and avoid spawning more than one concurrent task.
1039    fn refresh_context(
1040        &mut self,
1041        project: Entity<Project>,
1042        buffer: Entity<language::Buffer>,
1043        cursor_position: language::Anchor,
1044        cx: &mut Context<Self>,
1045    ) {
1046        let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
1047            return;
1048        };
1049
1050        zeta_project.refresh_context_task = Some(cx.spawn(async move |this, cx| {
1051            let related_excerpts = this
1052                .update(cx, |this, cx| {
1053                    let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
1054                        return Task::ready(anyhow::Ok(HashMap::default()));
1055                    };
1056
1057                    let ContextMode::Llm(options) = &this.options().context else {
1058                        return Task::ready(anyhow::Ok(HashMap::default()));
1059                    };
1060
1061                    find_related_excerpts(
1062                        buffer.clone(),
1063                        cursor_position,
1064                        &project,
1065                        zeta_project.events.iter(),
1066                        options,
1067                        cx,
1068                    )
1069                })
1070                .ok()?
1071                .await
1072                .log_err()
1073                .unwrap_or_default();
1074            this.update(cx, |this, _cx| {
1075                let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
1076                    return;
1077                };
1078                zeta_project.context = Some(related_excerpts);
1079                zeta_project.refresh_context_task.take();
1080            })
1081            .ok()
1082        }));
1083    }
1084
1085    fn gather_nearby_diagnostics(
1086        cursor_offset: usize,
1087        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
1088        snapshot: &BufferSnapshot,
1089        max_diagnostics_bytes: usize,
1090    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
1091        // TODO: Could make this more efficient
1092        let mut diagnostic_groups = Vec::new();
1093        for (language_server_id, diagnostics) in diagnostic_sets {
1094            let mut groups = Vec::new();
1095            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
1096            diagnostic_groups.extend(
1097                groups
1098                    .into_iter()
1099                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
1100            );
1101        }
1102
1103        // sort by proximity to cursor
1104        diagnostic_groups.sort_by_key(|group| {
1105            let range = &group.entries[group.primary_ix].range;
1106            if range.start >= cursor_offset {
1107                range.start - cursor_offset
1108            } else if cursor_offset >= range.end {
1109                cursor_offset - range.end
1110            } else {
1111                (cursor_offset - range.start).min(range.end - cursor_offset)
1112            }
1113        });
1114
1115        let mut results = Vec::new();
1116        let mut diagnostic_groups_truncated = false;
1117        let mut diagnostics_byte_count = 0;
1118        for group in diagnostic_groups {
1119            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
1120            diagnostics_byte_count += raw_value.get().len();
1121            if diagnostics_byte_count > max_diagnostics_bytes {
1122                diagnostic_groups_truncated = true;
1123                break;
1124            }
1125            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
1126        }
1127
1128        (results, diagnostic_groups_truncated)
1129    }
1130
1131    // TODO: Dedupe with similar code in request_prediction?
1132    pub fn cloud_request_for_zeta_cli(
1133        &mut self,
1134        project: &Entity<Project>,
1135        buffer: &Entity<Buffer>,
1136        position: language::Anchor,
1137        cx: &mut Context<Self>,
1138    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
1139        let project_state = self.projects.get(&project.entity_id());
1140
1141        let index_state = project_state.map(|state| {
1142            state
1143                .syntax_index
1144                .read_with(cx, |index, _cx| index.state().clone())
1145        });
1146        let options = self.options.clone();
1147        let snapshot = buffer.read(cx).snapshot();
1148        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
1149            return Task::ready(Err(anyhow!("No file path for excerpt")));
1150        };
1151        let worktree_snapshots = project
1152            .read(cx)
1153            .worktrees(cx)
1154            .map(|worktree| worktree.read(cx).snapshot())
1155            .collect::<Vec<_>>();
1156
1157        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
1158            let mut path = f.worktree.read(cx).absolutize(&f.path);
1159            if path.pop() { Some(path) } else { None }
1160        });
1161
1162        cx.background_spawn(async move {
1163            let index_state = if let Some(index_state) = index_state {
1164                Some(index_state.lock_owned().await)
1165            } else {
1166                None
1167            };
1168
1169            let cursor_point = position.to_point(&snapshot);
1170
1171            let debug_info = true;
1172            EditPredictionContext::gather_context(
1173                cursor_point,
1174                &snapshot,
1175                parent_abs_path.as_deref(),
1176                match &options.context {
1177                    ContextMode::Llm(_) => {
1178                        // TODO
1179                        panic!("Llm mode not supported in zeta cli yet");
1180                    }
1181                    ContextMode::Syntax(edit_prediction_context_options) => {
1182                        edit_prediction_context_options
1183                    }
1184                },
1185                index_state.as_deref(),
1186            )
1187            .context("Failed to select excerpt")
1188            .map(|context| {
1189                make_syntax_context_cloud_request(
1190                    excerpt_path.into(),
1191                    context,
1192                    // TODO pass everything
1193                    Vec::new(),
1194                    false,
1195                    Vec::new(),
1196                    false,
1197                    None,
1198                    debug_info,
1199                    &worktree_snapshots,
1200                    index_state.as_deref(),
1201                    Some(options.max_prompt_bytes),
1202                    options.prompt_format,
1203                )
1204            })
1205        })
1206    }
1207
1208    pub fn wait_for_initial_indexing(
1209        &mut self,
1210        project: &Entity<Project>,
1211        cx: &mut App,
1212    ) -> Task<Result<()>> {
1213        let zeta_project = self.get_or_init_zeta_project(project, cx);
1214        zeta_project
1215            .syntax_index
1216            .read(cx)
1217            .wait_for_initial_file_indexing(cx)
1218    }
1219}
1220
1221#[derive(Error, Debug)]
1222#[error(
1223    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
1224)]
1225pub struct ZedUpdateRequiredError {
1226    minimum_version: SemanticVersion,
1227}
1228
1229fn make_syntax_context_cloud_request(
1230    excerpt_path: Arc<Path>,
1231    context: EditPredictionContext,
1232    events: Vec<predict_edits_v3::Event>,
1233    can_collect_data: bool,
1234    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1235    diagnostic_groups_truncated: bool,
1236    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1237    debug_info: bool,
1238    worktrees: &Vec<worktree::Snapshot>,
1239    index_state: Option<&SyntaxIndexState>,
1240    prompt_max_bytes: Option<usize>,
1241    prompt_format: PromptFormat,
1242) -> predict_edits_v3::PredictEditsRequest {
1243    let mut signatures = Vec::new();
1244    let mut declaration_to_signature_index = HashMap::default();
1245    let mut referenced_declarations = Vec::new();
1246
1247    for snippet in context.declarations {
1248        let project_entry_id = snippet.declaration.project_entry_id();
1249        let Some(path) = worktrees.iter().find_map(|worktree| {
1250            worktree.entry_for_id(project_entry_id).map(|entry| {
1251                let mut full_path = RelPathBuf::new();
1252                full_path.push(worktree.root_name());
1253                full_path.push(&entry.path);
1254                full_path
1255            })
1256        }) else {
1257            continue;
1258        };
1259
1260        let parent_index = index_state.and_then(|index_state| {
1261            snippet.declaration.parent().and_then(|parent| {
1262                add_signature(
1263                    parent,
1264                    &mut declaration_to_signature_index,
1265                    &mut signatures,
1266                    index_state,
1267                )
1268            })
1269        });
1270
1271        let (text, text_is_truncated) = snippet.declaration.item_text();
1272        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1273            path: path.as_std_path().into(),
1274            text: text.into(),
1275            range: snippet.declaration.item_line_range(),
1276            text_is_truncated,
1277            signature_range: snippet.declaration.signature_range_in_item_text(),
1278            parent_index,
1279            signature_score: snippet.score(DeclarationStyle::Signature),
1280            declaration_score: snippet.score(DeclarationStyle::Declaration),
1281            score_components: snippet.components,
1282        });
1283    }
1284
1285    let excerpt_parent = index_state.and_then(|index_state| {
1286        context
1287            .excerpt
1288            .parent_declarations
1289            .last()
1290            .and_then(|(parent, _)| {
1291                add_signature(
1292                    *parent,
1293                    &mut declaration_to_signature_index,
1294                    &mut signatures,
1295                    index_state,
1296                )
1297            })
1298    });
1299
1300    predict_edits_v3::PredictEditsRequest {
1301        excerpt_path,
1302        excerpt: context.excerpt_text.body,
1303        excerpt_line_range: context.excerpt.line_range,
1304        excerpt_range: context.excerpt.range,
1305        cursor_point: predict_edits_v3::Point {
1306            line: predict_edits_v3::Line(context.cursor_point.row),
1307            column: context.cursor_point.column,
1308        },
1309        referenced_declarations,
1310        included_files: vec![],
1311        signatures,
1312        excerpt_parent,
1313        events,
1314        can_collect_data,
1315        diagnostic_groups,
1316        diagnostic_groups_truncated,
1317        git_info,
1318        debug_info,
1319        prompt_max_bytes,
1320        prompt_format,
1321    }
1322}
1323
1324fn add_signature(
1325    declaration_id: DeclarationId,
1326    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1327    signatures: &mut Vec<Signature>,
1328    index: &SyntaxIndexState,
1329) -> Option<usize> {
1330    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1331        return Some(*signature_index);
1332    }
1333    let Some(parent_declaration) = index.declaration(declaration_id) else {
1334        log::error!("bug: missing parent declaration");
1335        return None;
1336    };
1337    let parent_index = parent_declaration.parent().and_then(|parent| {
1338        add_signature(parent, declaration_to_signature_index, signatures, index)
1339    });
1340    let (text, text_is_truncated) = parent_declaration.signature_text();
1341    let signature_index = signatures.len();
1342    signatures.push(Signature {
1343        text: text.into(),
1344        text_is_truncated,
1345        parent_index,
1346        range: parent_declaration.signature_line_range(),
1347    });
1348    declaration_to_signature_index.insert(declaration_id, signature_index);
1349    Some(signature_index)
1350}
1351
1352#[cfg(test)]
1353mod tests {
1354    use std::{
1355        path::{Path, PathBuf},
1356        sync::Arc,
1357    };
1358
1359    use client::UserStore;
1360    use clock::FakeSystemClock;
1361    use cloud_llm_client::predict_edits_v3::{self, Point};
1362    use edit_prediction_context::Line;
1363    use futures::{
1364        AsyncReadExt, StreamExt,
1365        channel::{mpsc, oneshot},
1366    };
1367    use gpui::{
1368        Entity, TestAppContext,
1369        http_client::{FakeHttpClient, Response},
1370        prelude::*,
1371    };
1372    use indoc::indoc;
1373    use language::{LanguageServerId, OffsetRangeExt as _};
1374    use pretty_assertions::{assert_eq, assert_matches};
1375    use project::{FakeFs, Project};
1376    use serde_json::json;
1377    use settings::SettingsStore;
1378    use util::path;
1379    use uuid::Uuid;
1380
1381    use crate::{BufferEditPrediction, Zeta};
1382
    // Checks what `current_prediction_for_buffer` reports after `refresh_prediction`:
    // an edit to the requesting buffer's own file is `Local`; an edit to another file is a
    // `Jump` (to that file's path) for the requesting buffer, and becomes `Local` for the
    // other file's buffer once it is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Answer the outgoing request with an edit spanning every line of 1.txt.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Still requested from buffer1, but the response only edits 2.txt.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Once 2.txt is open, the same prediction is reported as local to its buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }
1479
    // A single `request_prediction` call: checks the outgoing request's excerpt path and
    // cursor point, then that a response rewriting every line of the file yields exactly one
    // edit, which inserts " are you?" at row 1, column 3.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        // Only the inserted text remains after comparing the response with the buffer.
        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }
1541
    // An edit made to a registered buffer before requesting a prediction must be sent as a
    // single non-predicted `BufferChange` event carrying a unified diff of that edit. The
    // response is then reduced to the single " are you?" insertion.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Registration comes before the edit so the edit below is recorded as an event.
        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }
1619
1620    #[gpui::test]
1621    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1622        let (zeta, mut req_rx) = init_test(cx);
1623        let fs = FakeFs::new(cx.executor());
1624        fs.insert_tree(
1625            "/root",
1626            json!({
1627                "foo.md": "Hello!\nBye"
1628            }),
1629        )
1630        .await;
1631        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1632
1633        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1634        let diagnostic = lsp::Diagnostic {
1635            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1636            severity: Some(lsp::DiagnosticSeverity::ERROR),
1637            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1638            ..Default::default()
1639        };
1640
1641        project.update(cx, |project, cx| {
1642            project.lsp_store().update(cx, |lsp_store, cx| {
1643                // Create some diagnostics
1644                lsp_store
1645                    .update_diagnostics(
1646                        LanguageServerId(0),
1647                        lsp::PublishDiagnosticsParams {
1648                            uri: path_to_buffer_uri.clone(),
1649                            diagnostics: vec![diagnostic],
1650                            version: None,
1651                        },
1652                        None,
1653                        language::DiagnosticSourceKind::Pushed,
1654                        &[],
1655                        cx,
1656                    )
1657                    .unwrap();
1658            });
1659        });
1660
1661        let buffer = project
1662            .update(cx, |project, cx| {
1663                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1664                project.open_buffer(path, cx)
1665            })
1666            .await
1667            .unwrap();
1668
1669        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1670        let position = snapshot.anchor_before(language::Point::new(0, 0));
1671
1672        let _prediction_task = zeta.update(cx, |zeta, cx| {
1673            zeta.request_prediction(&project, &buffer, position, cx)
1674        });
1675
1676        let (request, _respond_tx) = req_rx.next().await.unwrap();
1677
1678        assert_eq!(request.diagnostic_groups.len(), 1);
1679        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1680            .unwrap();
1681        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1682        assert_eq!(
1683            value,
1684            json!({
1685                "entries": [{
1686                    "range": {
1687                        "start": 8,
1688                        "end": 10
1689                    },
1690                    "diagnostic": {
1691                        "source": null,
1692                        "code": null,
1693                        "code_description": null,
1694                        "severity": 1,
1695                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1696                        "markdown": null,
1697                        "group_id": 0,
1698                        "is_primary": true,
1699                        "is_disk_based": false,
1700                        "is_unnecessary": false,
1701                        "source_kind": "Pushed",
1702                        "data": null,
1703                        "underline": true
1704                    }
1705                }],
1706                "primary_ix": 0
1707            })
1708        );
1709    }
1710
1711    fn init_test(
1712        cx: &mut TestAppContext,
1713    ) -> (
1714        Entity<Zeta>,
1715        mpsc::UnboundedReceiver<(
1716            predict_edits_v3::PredictEditsRequest,
1717            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1718        )>,
1719    ) {
1720        cx.update(move |cx| {
1721            let settings_store = SettingsStore::test(cx);
1722            cx.set_global(settings_store);
1723            language::init(cx);
1724            Project::init_settings(cx);
1725
1726            let (req_tx, req_rx) = mpsc::unbounded();
1727
1728            let http_client = FakeHttpClient::create({
1729                move |req| {
1730                    let uri = req.uri().path().to_string();
1731                    let mut body = req.into_body();
1732                    let req_tx = req_tx.clone();
1733                    async move {
1734                        let resp = match uri.as_str() {
1735                            "/client/llm_tokens" => serde_json::to_string(&json!({
1736                                "token": "test"
1737                            }))
1738                            .unwrap(),
1739                            "/predict_edits/v3" => {
1740                                let mut buf = Vec::new();
1741                                body.read_to_end(&mut buf).await.ok();
1742                                let req = serde_json::from_slice(&buf).unwrap();
1743
1744                                let (res_tx, res_rx) = oneshot::channel();
1745                                req_tx.unbounded_send((req, res_tx)).unwrap();
1746                                serde_json::to_string(&res_rx.await?).unwrap()
1747                            }
1748                            _ => {
1749                                panic!("Unexpected path: {}", uri)
1750                            }
1751                        };
1752
1753                        Ok(Response::builder().body(resp.into()).unwrap())
1754                    }
1755                }
1756            });
1757
1758            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1759            client.cloud_client().set_credentials(1, "test".into());
1760
1761            language_model::init(client.clone(), cx);
1762
1763            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1764            let zeta = Zeta::global(&client, &user_store, cx);
1765
1766            (zeta, req_rx)
1767        })
1768    }
1769}