zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
   9use edit_prediction_context::{
  10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  11    SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use futures::channel::mpsc;
  15use gpui::http_client::Method;
  16use gpui::{
  17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  18    http_client, prelude::*,
  19};
  20use language::BufferSnapshot;
  21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  22use language_model::{LlmApiToken, RefreshLlmTokenListener};
  23use project::Project;
  24use release_channel::AppVersion;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::Path;
  27use std::str::FromStr as _;
  28use std::sync::Arc;
  29use std::time::{Duration, Instant};
  30use thiserror::Error;
  31use util::rel_path::RelPathBuf;
  32use util::some_or_debug_panic;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35mod prediction;
  36mod provider;
  37
  38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
  39pub use provider::ZetaEditPredictionProvider;
  40
  41const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  42
  43/// Maximum number of events to track.
  44const MAX_EVENT_COUNT: usize = 16;
  45
  46pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  47    max_bytes: 512,
  48    min_bytes: 128,
  49    target_before_cursor_over_total_bytes: 0.5,
  50};
  51
  52pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  53    excerpt: DEFAULT_EXCERPT_OPTIONS,
  54    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  55    max_diagnostic_bytes: 2048,
  56    prompt_format: PromptFormat::MarkedExcerpt,
  57};
  58
/// Newtype wrapper that stores the app-wide `Zeta` entity as a gpui global.
/// It is installed by `Zeta::global` and read by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  63
/// Edit-prediction engine: tracks per-project edit history and requests predictions
/// from the Zed LLM service.
pub struct Zeta {
    /// Client used for HTTP requests and for acquiring/refreshing the LLM token.
    client: Arc<Client>,
    /// Source and sink of edit-prediction usage information.
    user_store: Entity<UserStore>,
    /// Bearer token sent with prediction requests; refreshed when the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh subscription created in `new` registered (never read).
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options used when building prediction requests.
    options: ZetaOptions,
    /// Set once the server has reported that this Zed version is too old.
    update_required: bool,
    /// When set (via `debug_info`), receives a debug message for each request that
    /// reaches the server.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
  74
/// Tunable options for building edit-prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Passed to `EditPredictionContext::gather_context` to select the excerpt around
    /// the cursor.
    pub excerpt: EditPredictionExcerptOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for serialized diagnostic groups in a request; groups beyond the
    /// budget are dropped and the request is marked as truncated.
    pub max_diagnostic_bytes: usize,
    /// Prompt format requested from the server.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
  82
/// Debug information for one prediction request.
///
/// Sent wrapped in `Ok` to the receiver returned by `Zeta::debug_info`. A failed request
/// is reported as `Err` with a message instead.
pub struct PredictionDebugInfo {
    /// The context gathered for the request (excerpt and referenced declarations).
    pub context: EditPredictionContext,
    /// Time from the start of context gathering until the request was built; this is
    /// measured before the network request is sent.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the response.
    pub request: RequestDebugInfo,
    /// The buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// The cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-provided debug payload attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  92
/// Per-project state tracked by `Zeta`.
struct ZetaProject {
    /// Syntax index used to gather declaration context for predictions in this project.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first, bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being observed, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A buffer whose edits are being observed.
struct RegisteredBuffer {
    /// Snapshot taken at registration or at the last reported change. The next change
    /// event is recorded relative to it.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's edit events and to its release (never read).
    _subscriptions: [gpui::Subscription; 2],
}
 103
/// A change recorded in a project's history and sent as context with prediction requests.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed from `old_snapshot` to `new_snapshot`. Consecutive
    /// changes to the same buffer may be coalesced into a single event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the change (or the latest change coalesced into it) was recorded.
        timestamp: Instant,
    },
}
 112
 113impl Zeta {
 114    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 115        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 116    }
 117
 118    pub fn global(
 119        client: &Arc<Client>,
 120        user_store: &Entity<UserStore>,
 121        cx: &mut App,
 122    ) -> Entity<Self> {
 123        cx.try_global::<ZetaGlobal>()
 124            .map(|global| global.0.clone())
 125            .unwrap_or_else(|| {
 126                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 127                cx.set_global(ZetaGlobal(zeta.clone()));
 128                zeta
 129            })
 130    }
 131
 132    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 133        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 134
 135        Self {
 136            projects: HashMap::new(),
 137            client,
 138            user_store,
 139            options: DEFAULT_OPTIONS,
 140            llm_token: LlmApiToken::default(),
 141            _llm_token_subscription: cx.subscribe(
 142                &refresh_llm_token_listener,
 143                |this, _listener, _event, cx| {
 144                    let client = this.client.clone();
 145                    let llm_token = this.llm_token.clone();
 146                    cx.spawn(async move |_this, _cx| {
 147                        llm_token.refresh(&client).await?;
 148                        anyhow::Ok(())
 149                    })
 150                    .detach_and_log_err(cx);
 151                },
 152            ),
 153            update_required: false,
 154            debug_tx: None,
 155        }
 156    }
 157
 158    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 159        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 160        self.debug_tx = Some(debug_watch_tx);
 161        debug_watch_rx
 162    }
 163
 164    pub fn options(&self) -> &ZetaOptions {
 165        &self.options
 166    }
 167
 168    pub fn set_options(&mut self, options: ZetaOptions) {
 169        self.options = options;
 170    }
 171
 172    pub fn clear_history(&mut self) {
 173        for zeta_project in self.projects.values_mut() {
 174            zeta_project.events.clear();
 175        }
 176    }
 177
 178    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 179        self.user_store.read(cx).edit_prediction_usage()
 180    }
 181
 182    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 183        self.get_or_init_zeta_project(project, cx);
 184    }
 185
 186    pub fn register_buffer(
 187        &mut self,
 188        buffer: &Entity<Buffer>,
 189        project: &Entity<Project>,
 190        cx: &mut Context<Self>,
 191    ) {
 192        let zeta_project = self.get_or_init_zeta_project(project, cx);
 193        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 194    }
 195
 196    fn get_or_init_zeta_project(
 197        &mut self,
 198        project: &Entity<Project>,
 199        cx: &mut App,
 200    ) -> &mut ZetaProject {
 201        self.projects
 202            .entry(project.entity_id())
 203            .or_insert_with(|| ZetaProject {
 204                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 205                events: VecDeque::new(),
 206                registered_buffers: HashMap::new(),
 207            })
 208    }
 209
 210    fn register_buffer_impl<'a>(
 211        zeta_project: &'a mut ZetaProject,
 212        buffer: &Entity<Buffer>,
 213        project: &Entity<Project>,
 214        cx: &mut Context<Self>,
 215    ) -> &'a mut RegisteredBuffer {
 216        let buffer_id = buffer.entity_id();
 217        match zeta_project.registered_buffers.entry(buffer_id) {
 218            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 219            hash_map::Entry::Vacant(entry) => {
 220                let snapshot = buffer.read(cx).snapshot();
 221                let project_entity_id = project.entity_id();
 222                entry.insert(RegisteredBuffer {
 223                    snapshot,
 224                    _subscriptions: [
 225                        cx.subscribe(buffer, {
 226                            let project = project.downgrade();
 227                            move |this, buffer, event, cx| {
 228                                if let language::BufferEvent::Edited = event
 229                                    && let Some(project) = project.upgrade()
 230                                {
 231                                    this.report_changes_for_buffer(&buffer, &project, cx);
 232                                }
 233                            }
 234                        }),
 235                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 236                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 237                            else {
 238                                return;
 239                            };
 240                            zeta_project.registered_buffers.remove(&buffer_id);
 241                        }),
 242                    ],
 243                })
 244            }
 245        }
 246    }
 247
 248    fn report_changes_for_buffer(
 249        &mut self,
 250        buffer: &Entity<Buffer>,
 251        project: &Entity<Project>,
 252        cx: &mut Context<Self>,
 253    ) -> BufferSnapshot {
 254        let zeta_project = self.get_or_init_zeta_project(project, cx);
 255        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 256
 257        let new_snapshot = buffer.read(cx).snapshot();
 258        if new_snapshot.version != registered_buffer.snapshot.version {
 259            let old_snapshot =
 260                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 261            Self::push_event(
 262                zeta_project,
 263                Event::BufferChange {
 264                    old_snapshot,
 265                    new_snapshot: new_snapshot.clone(),
 266                    timestamp: Instant::now(),
 267                },
 268            );
 269        }
 270
 271        new_snapshot
 272    }
 273
 274    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 275        let events = &mut zeta_project.events;
 276
 277        if let Some(Event::BufferChange {
 278            new_snapshot: last_new_snapshot,
 279            timestamp: last_timestamp,
 280            ..
 281        }) = events.back_mut()
 282        {
 283            // Coalesce edits for the same buffer when they happen one after the other.
 284            let Event::BufferChange {
 285                old_snapshot,
 286                new_snapshot,
 287                timestamp,
 288            } = &event;
 289
 290            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 291                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 292                && old_snapshot.version == last_new_snapshot.version
 293            {
 294                *last_new_snapshot = new_snapshot.clone();
 295                *last_timestamp = *timestamp;
 296                return;
 297            }
 298        }
 299
 300        if events.len() >= MAX_EVENT_COUNT {
 301            // These are halved instead of popping to improve prompt caching.
 302            events.drain(..MAX_EVENT_COUNT / 2);
 303        }
 304
 305        events.push_back(event);
 306    }
 307
 308    pub fn request_prediction(
 309        &mut self,
 310        project: &Entity<Project>,
 311        buffer: &Entity<Buffer>,
 312        position: language::Anchor,
 313        cx: &mut Context<Self>,
 314    ) -> Task<Result<Option<EditPrediction>>> {
 315        let project_state = self.projects.get(&project.entity_id());
 316
 317        let index_state = project_state.map(|state| {
 318            state
 319                .syntax_index
 320                .read_with(cx, |index, _cx| index.state().clone())
 321        });
 322        let options = self.options.clone();
 323        let snapshot = buffer.read(cx).snapshot();
 324        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 325            return Task::ready(Err(anyhow!("No file path for excerpt")));
 326        };
 327        let client = self.client.clone();
 328        let llm_token = self.llm_token.clone();
 329        let app_version = AppVersion::global(cx);
 330        let worktree_snapshots = project
 331            .read(cx)
 332            .worktrees(cx)
 333            .map(|worktree| worktree.read(cx).snapshot())
 334            .collect::<Vec<_>>();
 335        let debug_tx = self.debug_tx.clone();
 336
 337        let events = project_state
 338            .map(|state| {
 339                state
 340                    .events
 341                    .iter()
 342                    .filter_map(|event| match event {
 343                        Event::BufferChange {
 344                            old_snapshot,
 345                            new_snapshot,
 346                            ..
 347                        } => {
 348                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 349
 350                            let old_path = old_snapshot.file().and_then(|f| {
 351                                let old_path = f.full_path(cx);
 352                                if Some(&old_path) != path.as_ref() {
 353                                    Some(old_path)
 354                                } else {
 355                                    None
 356                                }
 357                            });
 358
 359                            // TODO [zeta2] move to bg?
 360                            let diff =
 361                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 362
 363                            if path == old_path && diff.is_empty() {
 364                                None
 365                            } else {
 366                                Some(predict_edits_v3::Event::BufferChange {
 367                                    old_path,
 368                                    path,
 369                                    diff,
 370                                    //todo: Actually detect if this edit was predicted or not
 371                                    predicted: false,
 372                                })
 373                            }
 374                        }
 375                    })
 376                    .collect::<Vec<_>>()
 377            })
 378            .unwrap_or_default();
 379
 380        let diagnostics = snapshot.diagnostic_sets().clone();
 381
 382        let request_task = cx.background_spawn({
 383            let snapshot = snapshot.clone();
 384            let buffer = buffer.clone();
 385            async move {
 386                let index_state = if let Some(index_state) = index_state {
 387                    Some(index_state.lock_owned().await)
 388                } else {
 389                    None
 390                };
 391
 392                let cursor_offset = position.to_offset(&snapshot);
 393                let cursor_point = cursor_offset.to_point(&snapshot);
 394
 395                let before_retrieval = chrono::Utc::now();
 396
 397                let Some(context) = EditPredictionContext::gather_context(
 398                    cursor_point,
 399                    &snapshot,
 400                    &options.excerpt,
 401                    index_state.as_deref(),
 402                ) else {
 403                    return Ok(None);
 404                };
 405
 406                let debug_context = if let Some(debug_tx) = debug_tx {
 407                    Some((debug_tx, context.clone()))
 408                } else {
 409                    None
 410                };
 411
 412                let (diagnostic_groups, diagnostic_groups_truncated) =
 413                    Self::gather_nearby_diagnostics(
 414                        cursor_offset,
 415                        &diagnostics,
 416                        &snapshot,
 417                        options.max_diagnostic_bytes,
 418                    );
 419
 420                let request = make_cloud_request(
 421                    excerpt_path,
 422                    context,
 423                    events,
 424                    // TODO data collection
 425                    false,
 426                    diagnostic_groups,
 427                    diagnostic_groups_truncated,
 428                    None,
 429                    debug_context.is_some(),
 430                    &worktree_snapshots,
 431                    index_state.as_deref(),
 432                    Some(options.max_prompt_bytes),
 433                    options.prompt_format,
 434                );
 435
 436                let retrieval_time = chrono::Utc::now() - before_retrieval;
 437                let response = Self::perform_request(client, llm_token, app_version, request).await;
 438
 439                if let Some((debug_tx, context)) = debug_context {
 440                    debug_tx
 441                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 442                            |response| {
 443                                let Some(request) =
 444                                    some_or_debug_panic(response.0.debug_info.clone())
 445                                else {
 446                                    return Err("Missing debug info".to_string());
 447                                };
 448                                Ok(PredictionDebugInfo {
 449                                    context,
 450                                    request,
 451                                    retrieval_time,
 452                                    buffer: buffer.downgrade(),
 453                                    position,
 454                                })
 455                            },
 456                        ))
 457                        .ok();
 458                }
 459
 460                let (response, usage) = response?;
 461                let edits = edits_from_response(&response.edits, &snapshot);
 462
 463                anyhow::Ok(Some((response.request_id, edits, usage)))
 464            }
 465        });
 466
 467        let buffer = buffer.clone();
 468
 469        cx.spawn(async move |this, cx| {
 470            match request_task.await {
 471                Ok(Some((id, edits, usage))) => {
 472                    if let Some(usage) = usage {
 473                        this.update(cx, |this, cx| {
 474                            this.user_store.update(cx, |user_store, cx| {
 475                                user_store.update_edit_prediction_usage(usage, cx);
 476                            });
 477                        })
 478                        .ok();
 479                    }
 480
 481                    // TODO telemetry: duration, etc
 482                    let Some((edits, snapshot, edit_preview_task)) =
 483                        buffer.read_with(cx, |buffer, cx| {
 484                            let new_snapshot = buffer.snapshot();
 485                            let edits: Arc<[_]> =
 486                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
 487                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 488                        })?
 489                    else {
 490                        return Ok(None);
 491                    };
 492
 493                    Ok(Some(EditPrediction {
 494                        id: id.into(),
 495                        edits,
 496                        snapshot,
 497                        edit_preview: edit_preview_task.await,
 498                    }))
 499                }
 500                Ok(None) => Ok(None),
 501                Err(err) => {
 502                    if err.is::<ZedUpdateRequiredError>() {
 503                        cx.update(|cx| {
 504                            this.update(cx, |this, _cx| {
 505                                this.update_required = true;
 506                            })
 507                            .ok();
 508
 509                            let error_message: SharedString = err.to_string().into();
 510                            show_app_notification(
 511                                NotificationId::unique::<ZedUpdateRequiredError>(),
 512                                cx,
 513                                move |cx| {
 514                                    cx.new(|cx| {
 515                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 516                                            .with_link_button(
 517                                                "Update Zed",
 518                                                "https://zed.dev/releases",
 519                                            )
 520                                    })
 521                                },
 522                            );
 523                        })
 524                        .ok();
 525                    }
 526
 527                    Err(err)
 528                }
 529            }
 530        })
 531    }
 532
 533    async fn perform_request(
 534        client: Arc<Client>,
 535        llm_token: LlmApiToken,
 536        app_version: SemanticVersion,
 537        request: predict_edits_v3::PredictEditsRequest,
 538    ) -> Result<(
 539        predict_edits_v3::PredictEditsResponse,
 540        Option<EditPredictionUsage>,
 541    )> {
 542        let http_client = client.http_client();
 543        let mut token = llm_token.acquire(&client).await?;
 544        let mut did_retry = false;
 545
 546        loop {
 547            let request_builder = http_client::Request::builder().method(Method::POST);
 548            let request_builder =
 549                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 550                    request_builder.uri(predict_edits_url)
 551                } else {
 552                    request_builder.uri(
 553                        http_client
 554                            .build_zed_llm_url("/predict_edits/v3", &[])?
 555                            .as_ref(),
 556                    )
 557                };
 558            let request = request_builder
 559                .header("Content-Type", "application/json")
 560                .header("Authorization", format!("Bearer {}", token))
 561                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 562                .body(serde_json::to_string(&request)?.into())?;
 563
 564            let mut response = http_client.send(request).await?;
 565
 566            if let Some(minimum_required_version) = response
 567                .headers()
 568                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 569                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 570            {
 571                anyhow::ensure!(
 572                    app_version >= minimum_required_version,
 573                    ZedUpdateRequiredError {
 574                        minimum_version: minimum_required_version
 575                    }
 576                );
 577            }
 578
 579            if response.status().is_success() {
 580                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 581
 582                let mut body = Vec::new();
 583                response.body_mut().read_to_end(&mut body).await?;
 584                return Ok((serde_json::from_slice(&body)?, usage));
 585            } else if !did_retry
 586                && response
 587                    .headers()
 588                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 589                    .is_some()
 590            {
 591                did_retry = true;
 592                token = llm_token.refresh(&client).await?;
 593            } else {
 594                let mut body = String::new();
 595                response.body_mut().read_to_string(&mut body).await?;
 596                anyhow::bail!(
 597                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 598                    response.status(),
 599                    body
 600                );
 601            }
 602        }
 603    }
 604
 605    fn gather_nearby_diagnostics(
 606        cursor_offset: usize,
 607        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 608        snapshot: &BufferSnapshot,
 609        max_diagnostics_bytes: usize,
 610    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 611        // TODO: Could make this more efficient
 612        let mut diagnostic_groups = Vec::new();
 613        for (language_server_id, diagnostics) in diagnostic_sets {
 614            let mut groups = Vec::new();
 615            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 616            diagnostic_groups.extend(
 617                groups
 618                    .into_iter()
 619                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 620            );
 621        }
 622
 623        // sort by proximity to cursor
 624        diagnostic_groups.sort_by_key(|group| {
 625            let range = &group.entries[group.primary_ix].range;
 626            if range.start >= cursor_offset {
 627                range.start - cursor_offset
 628            } else if cursor_offset >= range.end {
 629                cursor_offset - range.end
 630            } else {
 631                (cursor_offset - range.start).min(range.end - cursor_offset)
 632            }
 633        });
 634
 635        let mut results = Vec::new();
 636        let mut diagnostic_groups_truncated = false;
 637        let mut diagnostics_byte_count = 0;
 638        for group in diagnostic_groups {
 639            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 640            diagnostics_byte_count += raw_value.get().len();
 641            if diagnostics_byte_count > max_diagnostics_bytes {
 642                diagnostic_groups_truncated = true;
 643                break;
 644            }
 645            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 646        }
 647
 648        (results, diagnostic_groups_truncated)
 649    }
 650
 651    // TODO: Dedupe with similar code in request_prediction?
 652    pub fn cloud_request_for_zeta_cli(
 653        &mut self,
 654        project: &Entity<Project>,
 655        buffer: &Entity<Buffer>,
 656        position: language::Anchor,
 657        cx: &mut Context<Self>,
 658    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 659        let project_state = self.projects.get(&project.entity_id());
 660
 661        let index_state = project_state.map(|state| {
 662            state
 663                .syntax_index
 664                .read_with(cx, |index, _cx| index.state().clone())
 665        });
 666        let options = self.options.clone();
 667        let snapshot = buffer.read(cx).snapshot();
 668        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 669            return Task::ready(Err(anyhow!("No file path for excerpt")));
 670        };
 671        let worktree_snapshots = project
 672            .read(cx)
 673            .worktrees(cx)
 674            .map(|worktree| worktree.read(cx).snapshot())
 675            .collect::<Vec<_>>();
 676
 677        cx.background_spawn(async move {
 678            let index_state = if let Some(index_state) = index_state {
 679                Some(index_state.lock_owned().await)
 680            } else {
 681                None
 682            };
 683
 684            let cursor_point = position.to_point(&snapshot);
 685
 686            let debug_info = true;
 687            EditPredictionContext::gather_context(
 688                cursor_point,
 689                &snapshot,
 690                &options.excerpt,
 691                index_state.as_deref(),
 692            )
 693            .context("Failed to select excerpt")
 694            .map(|context| {
 695                make_cloud_request(
 696                    excerpt_path.into(),
 697                    context,
 698                    // TODO pass everything
 699                    Vec::new(),
 700                    false,
 701                    Vec::new(),
 702                    false,
 703                    None,
 704                    debug_info,
 705                    &worktree_snapshots,
 706                    index_state.as_deref(),
 707                    Some(options.max_prompt_bytes),
 708                    options.prompt_format,
 709                )
 710            })
 711        })
 712    }
 713}
 714
/// Error returned when a response's minimum-required-version header names a Zed version
/// newer than the running one.
///
/// `Zeta::request_prediction` checks for this error type to set `update_required` and to
/// show an "Update Zed" notification.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Minimum version reported by the server.
    minimum_version: SemanticVersion,
}
 722
 723fn make_cloud_request(
 724    excerpt_path: Arc<Path>,
 725    context: EditPredictionContext,
 726    events: Vec<predict_edits_v3::Event>,
 727    can_collect_data: bool,
 728    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 729    diagnostic_groups_truncated: bool,
 730    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 731    debug_info: bool,
 732    worktrees: &Vec<worktree::Snapshot>,
 733    index_state: Option<&SyntaxIndexState>,
 734    prompt_max_bytes: Option<usize>,
 735    prompt_format: PromptFormat,
 736) -> predict_edits_v3::PredictEditsRequest {
 737    let mut signatures = Vec::new();
 738    let mut declaration_to_signature_index = HashMap::default();
 739    let mut referenced_declarations = Vec::new();
 740
 741    for snippet in context.declarations {
 742        let project_entry_id = snippet.declaration.project_entry_id();
 743        let Some(path) = worktrees.iter().find_map(|worktree| {
 744            worktree.entry_for_id(project_entry_id).map(|entry| {
 745                let mut full_path = RelPathBuf::new();
 746                full_path.push(worktree.root_name());
 747                full_path.push(&entry.path);
 748                full_path
 749            })
 750        }) else {
 751            continue;
 752        };
 753
 754        let parent_index = index_state.and_then(|index_state| {
 755            snippet.declaration.parent().and_then(|parent| {
 756                add_signature(
 757                    parent,
 758                    &mut declaration_to_signature_index,
 759                    &mut signatures,
 760                    index_state,
 761                )
 762            })
 763        });
 764
 765        let (text, text_is_truncated) = snippet.declaration.item_text();
 766        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 767            path: path.as_std_path().into(),
 768            text: text.into(),
 769            range: snippet.declaration.item_range(),
 770            text_is_truncated,
 771            signature_range: snippet.declaration.signature_range_in_item_text(),
 772            parent_index,
 773            score_components: snippet.score_components,
 774            signature_score: snippet.scores.signature,
 775            declaration_score: snippet.scores.declaration,
 776        });
 777    }
 778
 779    let excerpt_parent = index_state.and_then(|index_state| {
 780        context
 781            .excerpt
 782            .parent_declarations
 783            .last()
 784            .and_then(|(parent, _)| {
 785                add_signature(
 786                    *parent,
 787                    &mut declaration_to_signature_index,
 788                    &mut signatures,
 789                    index_state,
 790                )
 791            })
 792    });
 793
 794    predict_edits_v3::PredictEditsRequest {
 795        excerpt_path,
 796        excerpt: context.excerpt_text.body,
 797        excerpt_range: context.excerpt.range,
 798        cursor_offset: context.cursor_offset_in_excerpt,
 799        referenced_declarations,
 800        signatures,
 801        excerpt_parent,
 802        events,
 803        can_collect_data,
 804        diagnostic_groups,
 805        diagnostic_groups_truncated,
 806        git_info,
 807        debug_info,
 808        prompt_max_bytes,
 809        prompt_format,
 810    }
 811}
 812
 813fn add_signature(
 814    declaration_id: DeclarationId,
 815    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 816    signatures: &mut Vec<Signature>,
 817    index: &SyntaxIndexState,
 818) -> Option<usize> {
 819    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 820        return Some(*signature_index);
 821    }
 822    let Some(parent_declaration) = index.declaration(declaration_id) else {
 823        log::error!("bug: missing parent declaration");
 824        return None;
 825    };
 826    let parent_index = parent_declaration.parent().and_then(|parent| {
 827        add_signature(parent, declaration_to_signature_index, signatures, index)
 828    });
 829    let (text, text_is_truncated) = parent_declaration.signature_text();
 830    let signature_index = signatures.len();
 831    signatures.push(Signature {
 832        text: text.into(),
 833        text_is_truncated,
 834        parent_index,
 835        range: parent_declaration.signature_range(),
 836    });
 837    declaration_to_signature_index.insert(declaration_id, signature_index);
 838    Some(signature_index)
 839}
 840
 841#[cfg(test)]
 842mod tests {
 843    use std::{
 844        path::{Path, PathBuf},
 845        sync::Arc,
 846    };
 847
 848    use client::UserStore;
 849    use clock::FakeSystemClock;
 850    use cloud_llm_client::predict_edits_v3;
 851    use futures::{
 852        AsyncReadExt, StreamExt,
 853        channel::{mpsc, oneshot},
 854    };
 855    use gpui::{
 856        Entity, TestAppContext,
 857        http_client::{FakeHttpClient, Response},
 858        prelude::*,
 859    };
 860    use indoc::indoc;
 861    use language::{LanguageServerId, OffsetRangeExt as _};
 862    use project::{FakeFs, Project};
 863    use serde_json::json;
 864    use settings::SettingsStore;
 865    use util::path;
 866    use uuid::Uuid;
 867
 868    use crate::Zeta;
 869
 870    #[gpui::test]
 871    async fn test_simple_request(cx: &mut TestAppContext) {
 872        let (zeta, mut req_rx) = init_test(cx);
 873        let fs = FakeFs::new(cx.executor());
 874        fs.insert_tree(
 875            "/root",
 876            json!({
 877                "foo.md":  "Hello!\nHow\nBye"
 878            }),
 879        )
 880        .await;
 881        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 882
 883        let buffer = project
 884            .update(cx, |project, cx| {
 885                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 886                project.open_buffer(path, cx)
 887            })
 888            .await
 889            .unwrap();
 890        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 891        let position = snapshot.anchor_before(language::Point::new(1, 3));
 892
 893        let prediction_task = zeta.update(cx, |zeta, cx| {
 894            zeta.request_prediction(&project, &buffer, position, cx)
 895        });
 896
 897        let (request, respond_tx) = req_rx.next().await.unwrap();
 898        assert_eq!(
 899            request.excerpt_path.as_ref(),
 900            Path::new(path!("root/foo.md"))
 901        );
 902        assert_eq!(request.cursor_offset, 10);
 903
 904        respond_tx
 905            .send(predict_edits_v3::PredictEditsResponse {
 906                request_id: Uuid::new_v4(),
 907                edits: vec![predict_edits_v3::Edit {
 908                    path: Path::new(path!("root/foo.md")).into(),
 909                    range: 0..snapshot.len(),
 910                    content: "Hello!\nHow are you?\nBye".into(),
 911                }],
 912                debug_info: None,
 913            })
 914            .unwrap();
 915
 916        let prediction = prediction_task.await.unwrap().unwrap();
 917
 918        assert_eq!(prediction.edits.len(), 1);
 919        assert_eq!(
 920            prediction.edits[0].0.to_point(&snapshot).start,
 921            language::Point::new(1, 3)
 922        );
 923        assert_eq!(prediction.edits[0].1, " are you?");
 924    }
 925
 926    #[gpui::test]
 927    async fn test_request_events(cx: &mut TestAppContext) {
 928        let (zeta, mut req_rx) = init_test(cx);
 929        let fs = FakeFs::new(cx.executor());
 930        fs.insert_tree(
 931            "/root",
 932            json!({
 933                "foo.md": "Hello!\n\nBye"
 934            }),
 935        )
 936        .await;
 937        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 938
 939        let buffer = project
 940            .update(cx, |project, cx| {
 941                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 942                project.open_buffer(path, cx)
 943            })
 944            .await
 945            .unwrap();
 946
 947        zeta.update(cx, |zeta, cx| {
 948            zeta.register_buffer(&buffer, &project, cx);
 949        });
 950
 951        buffer.update(cx, |buffer, cx| {
 952            buffer.edit(vec![(7..7, "How")], None, cx);
 953        });
 954
 955        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 956        let position = snapshot.anchor_before(language::Point::new(1, 3));
 957
 958        let prediction_task = zeta.update(cx, |zeta, cx| {
 959            zeta.request_prediction(&project, &buffer, position, cx)
 960        });
 961
 962        let (request, respond_tx) = req_rx.next().await.unwrap();
 963
 964        assert_eq!(request.events.len(), 1);
 965        assert_eq!(
 966            request.events[0],
 967            predict_edits_v3::Event::BufferChange {
 968                path: Some(PathBuf::from(path!("root/foo.md"))),
 969                old_path: None,
 970                diff: indoc! {"
 971                        @@ -1,3 +1,3 @@
 972                         Hello!
 973                        -
 974                        +How
 975                         Bye
 976                    "}
 977                .to_string(),
 978                predicted: false
 979            }
 980        );
 981
 982        respond_tx
 983            .send(predict_edits_v3::PredictEditsResponse {
 984                request_id: Uuid::new_v4(),
 985                edits: vec![predict_edits_v3::Edit {
 986                    path: Path::new(path!("root/foo.md")).into(),
 987                    range: 0..snapshot.len(),
 988                    content: "Hello!\nHow are you?\nBye".into(),
 989                }],
 990                debug_info: None,
 991            })
 992            .unwrap();
 993
 994        let prediction = prediction_task.await.unwrap().unwrap();
 995
 996        assert_eq!(prediction.edits.len(), 1);
 997        assert_eq!(
 998            prediction.edits[0].0.to_point(&snapshot).start,
 999            language::Point::new(1, 3)
1000        );
1001        assert_eq!(prediction.edits[0].1, " are you?");
1002    }
1003
1004    #[gpui::test]
1005    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1006        let (zeta, mut req_rx) = init_test(cx);
1007        let fs = FakeFs::new(cx.executor());
1008        fs.insert_tree(
1009            "/root",
1010            json!({
1011                "foo.md": "Hello!\nBye"
1012            }),
1013        )
1014        .await;
1015        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1016
1017        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1018        let diagnostic = lsp::Diagnostic {
1019            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1020            severity: Some(lsp::DiagnosticSeverity::ERROR),
1021            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1022            ..Default::default()
1023        };
1024
1025        project.update(cx, |project, cx| {
1026            project.lsp_store().update(cx, |lsp_store, cx| {
1027                // Create some diagnostics
1028                lsp_store
1029                    .update_diagnostics(
1030                        LanguageServerId(0),
1031                        lsp::PublishDiagnosticsParams {
1032                            uri: path_to_buffer_uri.clone(),
1033                            diagnostics: vec![diagnostic],
1034                            version: None,
1035                        },
1036                        None,
1037                        language::DiagnosticSourceKind::Pushed,
1038                        &[],
1039                        cx,
1040                    )
1041                    .unwrap();
1042            });
1043        });
1044
1045        let buffer = project
1046            .update(cx, |project, cx| {
1047                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1048                project.open_buffer(path, cx)
1049            })
1050            .await
1051            .unwrap();
1052
1053        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1054        let position = snapshot.anchor_before(language::Point::new(0, 0));
1055
1056        let _prediction_task = zeta.update(cx, |zeta, cx| {
1057            zeta.request_prediction(&project, &buffer, position, cx)
1058        });
1059
1060        let (request, _respond_tx) = req_rx.next().await.unwrap();
1061
1062        assert_eq!(request.diagnostic_groups.len(), 1);
1063        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1064            .unwrap();
1065        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1066        assert_eq!(
1067            value,
1068            json!({
1069                "entries": [{
1070                    "range": {
1071                        "start": 8,
1072                        "end": 10
1073                    },
1074                    "diagnostic": {
1075                        "source": null,
1076                        "code": null,
1077                        "code_description": null,
1078                        "severity": 1,
1079                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1080                        "markdown": null,
1081                        "group_id": 0,
1082                        "is_primary": true,
1083                        "is_disk_based": false,
1084                        "is_unnecessary": false,
1085                        "source_kind": "Pushed",
1086                        "data": null,
1087                        "underline": true
1088                    }
1089                }],
1090                "primary_ix": 0
1091            })
1092        );
1093    }
1094
1095    fn init_test(
1096        cx: &mut TestAppContext,
1097    ) -> (
1098        Entity<Zeta>,
1099        mpsc::UnboundedReceiver<(
1100            predict_edits_v3::PredictEditsRequest,
1101            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1102        )>,
1103    ) {
1104        cx.update(move |cx| {
1105            let settings_store = SettingsStore::test(cx);
1106            cx.set_global(settings_store);
1107            language::init(cx);
1108            Project::init_settings(cx);
1109
1110            let (req_tx, req_rx) = mpsc::unbounded();
1111
1112            let http_client = FakeHttpClient::create({
1113                move |req| {
1114                    let uri = req.uri().path().to_string();
1115                    let mut body = req.into_body();
1116                    let req_tx = req_tx.clone();
1117                    async move {
1118                        let resp = match uri.as_str() {
1119                            "/client/llm_tokens" => serde_json::to_string(&json!({
1120                                "token": "test"
1121                            }))
1122                            .unwrap(),
1123                            "/predict_edits/v3" => {
1124                                let mut buf = Vec::new();
1125                                body.read_to_end(&mut buf).await.ok();
1126                                let req = serde_json::from_slice(&buf).unwrap();
1127
1128                                let (res_tx, res_rx) = oneshot::channel();
1129                                req_tx.unbounded_send((req, res_tx)).unwrap();
1130                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1131                            }
1132                            _ => {
1133                                panic!("Unexpected path: {}", uri)
1134                            }
1135                        };
1136
1137                        Ok(Response::builder().body(resp.into()).unwrap())
1138                    }
1139                }
1140            });
1141
1142            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1143            client.cloud_client().set_credentials(1, "test".into());
1144
1145            language_model::init(client.clone(), cx);
1146
1147            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1148            let zeta = Zeta::global(&client, &user_store, cx);
1149            (zeta, req_rx)
1150        })
1151    }
1152}