zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
   9use edit_prediction_context::{
  10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  11    SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use futures::channel::mpsc;
  15use gpui::http_client::Method;
  16use gpui::{
  17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  18    http_client, prelude::*,
  19};
  20use language::BufferSnapshot;
  21use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  22use language_model::{LlmApiToken, RefreshLlmTokenListener};
  23use project::Project;
  24use release_channel::AppVersion;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::Path;
  27use std::str::FromStr as _;
  28use std::sync::Arc;
  29use std::time::{Duration, Instant};
  30use thiserror::Error;
  31use util::rel_path::RelPathBuf;
  32use util::some_or_debug_panic;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35mod prediction;
  36mod provider;
  37
  38use crate::prediction::{EditPrediction, edits_from_response, interpolate_edits};
  39pub use provider::ZetaEditPredictionProvider;
  40
/// Maximum gap between two consecutive edits of the same buffer for them to be
/// merged into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When a project's history has reached this length, `Zeta::push_event` drops
/// the oldest half before appending the new event.
const MAX_EVENT_COUNT: usize = 16;

/// Excerpt options used by [`DEFAULT_OPTIONS`].
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Options a newly constructed [`Zeta`] starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::MarkedExcerpt,
};
  58
/// Wrapper stored as an app-wide global so the shared [`Zeta`] entity can be
/// looked up from an `App` (see `Zeta::try_global` and `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  63
/// Edit prediction state: per-project edit history plus everything needed to
/// send prediction requests.
pub struct Zeta {
    /// Client used to obtain the HTTP client and to acquire/refresh the LLM token.
    client: Arc<Client>,
    /// Store whose edit prediction usage is read by `usage` and updated after
    /// each successful request that reports usage.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential on prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the subscription alive that refreshes `llm_token` whenever the
    /// `RefreshLlmTokenListener` emits an event.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to subsequent prediction requests.
    options: ZetaOptions,
    /// Set to `true` once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// Sender installed by `debug_info`; when present, each request reports
    /// either its debug info or an error string through it.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
  74
/// Tunable parameters for building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context`.
    pub excerpt: EditPredictionExcerptOptions,
    /// Sent to the server as the request's `prompt_max_bytes`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups attached to a request.
    pub max_diagnostic_bytes: usize,
    /// Sent to the server as the request's `prompt_format`.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
  82
/// Debug payload sent through the channel returned by `Zeta::debug_info` for
/// each prediction request that received a response carrying debug info.
pub struct PredictionDebugInfo {
    /// The context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time elapsed between starting context gathering and finishing
    /// construction of the cloud request.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the response.
    pub request: RequestDebugInfo,
    /// The buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// The position the prediction was requested at.
    pub position: language::Anchor,
}

/// Alias for the debug info type carried by prediction responses.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  92
/// State tracked for each registered project.
struct ZetaProject {
    /// Syntax index created for the project when it is first registered.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers being observed for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
  98
/// A buffer whose edits are being recorded.
struct RegisteredBuffer {
    /// Snapshot taken at registration or at the most recently reported change;
    /// used as the `old_snapshot` of the next `Event::BufferChange`.
    snapshot: BufferSnapshot,
    /// Subscriptions for the buffer's edit events and for its release.
    _subscriptions: [gpui::Subscription; 2],
}
 103
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer contents before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer contents after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded; updated when later edits are merged
        /// into this event.
        timestamp: Instant,
    },
}
 112
 113impl Zeta {
 114    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 115        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 116    }
 117
 118    pub fn global(
 119        client: &Arc<Client>,
 120        user_store: &Entity<UserStore>,
 121        cx: &mut App,
 122    ) -> Entity<Self> {
 123        cx.try_global::<ZetaGlobal>()
 124            .map(|global| global.0.clone())
 125            .unwrap_or_else(|| {
 126                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 127                cx.set_global(ZetaGlobal(zeta.clone()));
 128                zeta
 129            })
 130    }
 131
 132    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 133        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 134
 135        Self {
 136            projects: HashMap::new(),
 137            client,
 138            user_store,
 139            options: DEFAULT_OPTIONS,
 140            llm_token: LlmApiToken::default(),
 141            _llm_token_subscription: cx.subscribe(
 142                &refresh_llm_token_listener,
 143                |this, _listener, _event, cx| {
 144                    let client = this.client.clone();
 145                    let llm_token = this.llm_token.clone();
 146                    cx.spawn(async move |_this, _cx| {
 147                        llm_token.refresh(&client).await?;
 148                        anyhow::Ok(())
 149                    })
 150                    .detach_and_log_err(cx);
 151                },
 152            ),
 153            update_required: false,
 154            debug_tx: None,
 155        }
 156    }
 157
 158    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 159        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 160        self.debug_tx = Some(debug_watch_tx);
 161        debug_watch_rx
 162    }
 163
    /// Returns the options currently in effect.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 167
    /// Replaces the options used for subsequent requests.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 171
 172    pub fn clear_history(&mut self) {
 173        for zeta_project in self.projects.values_mut() {
 174            zeta_project.events.clear();
 175        }
 176    }
 177
 178    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 179        self.user_store.read(cx).edit_prediction_usage()
 180    }
 181
 182    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 183        self.get_or_init_zeta_project(project, cx);
 184    }
 185
 186    pub fn register_buffer(
 187        &mut self,
 188        buffer: &Entity<Buffer>,
 189        project: &Entity<Project>,
 190        cx: &mut Context<Self>,
 191    ) {
 192        let zeta_project = self.get_or_init_zeta_project(project, cx);
 193        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 194    }
 195
 196    fn get_or_init_zeta_project(
 197        &mut self,
 198        project: &Entity<Project>,
 199        cx: &mut App,
 200    ) -> &mut ZetaProject {
 201        self.projects
 202            .entry(project.entity_id())
 203            .or_insert_with(|| ZetaProject {
 204                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 205                events: VecDeque::new(),
 206                registered_buffers: HashMap::new(),
 207            })
 208    }
 209
 210    fn register_buffer_impl<'a>(
 211        zeta_project: &'a mut ZetaProject,
 212        buffer: &Entity<Buffer>,
 213        project: &Entity<Project>,
 214        cx: &mut Context<Self>,
 215    ) -> &'a mut RegisteredBuffer {
 216        let buffer_id = buffer.entity_id();
 217        match zeta_project.registered_buffers.entry(buffer_id) {
 218            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 219            hash_map::Entry::Vacant(entry) => {
 220                let snapshot = buffer.read(cx).snapshot();
 221                let project_entity_id = project.entity_id();
 222                entry.insert(RegisteredBuffer {
 223                    snapshot,
 224                    _subscriptions: [
 225                        cx.subscribe(buffer, {
 226                            let project = project.downgrade();
 227                            move |this, buffer, event, cx| {
 228                                if let language::BufferEvent::Edited = event
 229                                    && let Some(project) = project.upgrade()
 230                                {
 231                                    this.report_changes_for_buffer(&buffer, &project, cx);
 232                                }
 233                            }
 234                        }),
 235                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 236                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 237                            else {
 238                                return;
 239                            };
 240                            zeta_project.registered_buffers.remove(&buffer_id);
 241                        }),
 242                    ],
 243                })
 244            }
 245        }
 246    }
 247
 248    fn report_changes_for_buffer(
 249        &mut self,
 250        buffer: &Entity<Buffer>,
 251        project: &Entity<Project>,
 252        cx: &mut Context<Self>,
 253    ) -> BufferSnapshot {
 254        let zeta_project = self.get_or_init_zeta_project(project, cx);
 255        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 256
 257        let new_snapshot = buffer.read(cx).snapshot();
 258        if new_snapshot.version != registered_buffer.snapshot.version {
 259            let old_snapshot =
 260                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 261            Self::push_event(
 262                zeta_project,
 263                Event::BufferChange {
 264                    old_snapshot,
 265                    new_snapshot: new_snapshot.clone(),
 266                    timestamp: Instant::now(),
 267                },
 268            );
 269        }
 270
 271        new_snapshot
 272    }
 273
 274    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 275        let events = &mut zeta_project.events;
 276
 277        if let Some(Event::BufferChange {
 278            new_snapshot: last_new_snapshot,
 279            timestamp: last_timestamp,
 280            ..
 281        }) = events.back_mut()
 282        {
 283            // Coalesce edits for the same buffer when they happen one after the other.
 284            let Event::BufferChange {
 285                old_snapshot,
 286                new_snapshot,
 287                timestamp,
 288            } = &event;
 289
 290            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 291                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 292                && old_snapshot.version == last_new_snapshot.version
 293            {
 294                *last_new_snapshot = new_snapshot.clone();
 295                *last_timestamp = *timestamp;
 296                return;
 297            }
 298        }
 299
 300        if events.len() >= MAX_EVENT_COUNT {
 301            // These are halved instead of popping to improve prompt caching.
 302            events.drain(..MAX_EVENT_COUNT / 2);
 303        }
 304
 305        events.push_back(event);
 306    }
 307
 308    pub fn request_prediction(
 309        &mut self,
 310        project: &Entity<Project>,
 311        buffer: &Entity<Buffer>,
 312        position: language::Anchor,
 313        cx: &mut Context<Self>,
 314    ) -> Task<Result<Option<EditPrediction>>> {
 315        let project_state = self.projects.get(&project.entity_id());
 316
 317        let index_state = project_state.map(|state| {
 318            state
 319                .syntax_index
 320                .read_with(cx, |index, _cx| index.state().clone())
 321        });
 322        let options = self.options.clone();
 323        let snapshot = buffer.read(cx).snapshot();
 324        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 325            return Task::ready(Err(anyhow!("No file path for excerpt")));
 326        };
 327        let client = self.client.clone();
 328        let llm_token = self.llm_token.clone();
 329        let app_version = AppVersion::global(cx);
 330        let worktree_snapshots = project
 331            .read(cx)
 332            .worktrees(cx)
 333            .map(|worktree| worktree.read(cx).snapshot())
 334            .collect::<Vec<_>>();
 335        let debug_tx = self.debug_tx.clone();
 336
 337        let events = project_state
 338            .map(|state| {
 339                state
 340                    .events
 341                    .iter()
 342                    .map(|event| match event {
 343                        Event::BufferChange {
 344                            old_snapshot,
 345                            new_snapshot,
 346                            ..
 347                        } => {
 348                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 349
 350                            let old_path = old_snapshot.file().and_then(|f| {
 351                                let old_path = f.full_path(cx);
 352                                if Some(&old_path) != path.as_ref() {
 353                                    Some(old_path)
 354                                } else {
 355                                    None
 356                                }
 357                            });
 358
 359                            predict_edits_v3::Event::BufferChange {
 360                                old_path,
 361                                path,
 362                                diff: language::unified_diff(
 363                                    &old_snapshot.text(),
 364                                    &new_snapshot.text(),
 365                                ),
 366                                //todo: Actually detect if this edit was predicted or not
 367                                predicted: false,
 368                            }
 369                        }
 370                    })
 371                    .collect::<Vec<_>>()
 372            })
 373            .unwrap_or_default();
 374
 375        let diagnostics = snapshot.diagnostic_sets().clone();
 376
 377        let request_task = cx.background_spawn({
 378            let snapshot = snapshot.clone();
 379            let buffer = buffer.clone();
 380            async move {
 381                let index_state = if let Some(index_state) = index_state {
 382                    Some(index_state.lock_owned().await)
 383                } else {
 384                    None
 385                };
 386
 387                let cursor_offset = position.to_offset(&snapshot);
 388                let cursor_point = cursor_offset.to_point(&snapshot);
 389
 390                let before_retrieval = chrono::Utc::now();
 391
 392                let Some(context) = EditPredictionContext::gather_context(
 393                    cursor_point,
 394                    &snapshot,
 395                    &options.excerpt,
 396                    index_state.as_deref(),
 397                ) else {
 398                    return Ok(None);
 399                };
 400
 401                let debug_context = if let Some(debug_tx) = debug_tx {
 402                    Some((debug_tx, context.clone()))
 403                } else {
 404                    None
 405                };
 406
 407                let (diagnostic_groups, diagnostic_groups_truncated) =
 408                    Self::gather_nearby_diagnostics(
 409                        cursor_offset,
 410                        &diagnostics,
 411                        &snapshot,
 412                        options.max_diagnostic_bytes,
 413                    );
 414
 415                let request = make_cloud_request(
 416                    excerpt_path,
 417                    context,
 418                    events,
 419                    // TODO data collection
 420                    false,
 421                    diagnostic_groups,
 422                    diagnostic_groups_truncated,
 423                    None,
 424                    debug_context.is_some(),
 425                    &worktree_snapshots,
 426                    index_state.as_deref(),
 427                    Some(options.max_prompt_bytes),
 428                    options.prompt_format,
 429                );
 430
 431                let retrieval_time = chrono::Utc::now() - before_retrieval;
 432                let response = Self::perform_request(client, llm_token, app_version, request).await;
 433
 434                if let Some((debug_tx, context)) = debug_context {
 435                    debug_tx
 436                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 437                            |response| {
 438                                let Some(request) =
 439                                    some_or_debug_panic(response.0.debug_info.clone())
 440                                else {
 441                                    return Err("Missing debug info".to_string());
 442                                };
 443                                Ok(PredictionDebugInfo {
 444                                    context,
 445                                    request,
 446                                    retrieval_time,
 447                                    buffer: buffer.downgrade(),
 448                                    position,
 449                                })
 450                            },
 451                        ))
 452                        .ok();
 453                }
 454
 455                let (response, usage) = response?;
 456                let edits = edits_from_response(&response.edits, &snapshot);
 457
 458                anyhow::Ok(Some((response.request_id, edits, usage)))
 459            }
 460        });
 461
 462        let buffer = buffer.clone();
 463
 464        cx.spawn(async move |this, cx| {
 465            match request_task.await {
 466                Ok(Some((id, edits, usage))) => {
 467                    if let Some(usage) = usage {
 468                        this.update(cx, |this, cx| {
 469                            this.user_store.update(cx, |user_store, cx| {
 470                                user_store.update_edit_prediction_usage(usage, cx);
 471                            });
 472                        })
 473                        .ok();
 474                    }
 475
 476                    // TODO telemetry: duration, etc
 477                    let Some((edits, snapshot, edit_preview_task)) =
 478                        buffer.read_with(cx, |buffer, cx| {
 479                            let new_snapshot = buffer.snapshot();
 480                            let edits: Arc<[_]> =
 481                                interpolate_edits(&snapshot, &new_snapshot, edits)?.into();
 482                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 483                        })?
 484                    else {
 485                        return Ok(None);
 486                    };
 487
 488                    Ok(Some(EditPrediction {
 489                        id: id.into(),
 490                        edits,
 491                        snapshot,
 492                        edit_preview: edit_preview_task.await,
 493                    }))
 494                }
 495                Ok(None) => Ok(None),
 496                Err(err) => {
 497                    if err.is::<ZedUpdateRequiredError>() {
 498                        cx.update(|cx| {
 499                            this.update(cx, |this, _cx| {
 500                                this.update_required = true;
 501                            })
 502                            .ok();
 503
 504                            let error_message: SharedString = err.to_string().into();
 505                            show_app_notification(
 506                                NotificationId::unique::<ZedUpdateRequiredError>(),
 507                                cx,
 508                                move |cx| {
 509                                    cx.new(|cx| {
 510                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 511                                            .with_link_button(
 512                                                "Update Zed",
 513                                                "https://zed.dev/releases",
 514                                            )
 515                                    })
 516                                },
 517                            );
 518                        })
 519                        .ok();
 520                    }
 521
 522                    Err(err)
 523                }
 524            }
 525        })
 526    }
 527
 528    async fn perform_request(
 529        client: Arc<Client>,
 530        llm_token: LlmApiToken,
 531        app_version: SemanticVersion,
 532        request: predict_edits_v3::PredictEditsRequest,
 533    ) -> Result<(
 534        predict_edits_v3::PredictEditsResponse,
 535        Option<EditPredictionUsage>,
 536    )> {
 537        let http_client = client.http_client();
 538        let mut token = llm_token.acquire(&client).await?;
 539        let mut did_retry = false;
 540
 541        loop {
 542            let request_builder = http_client::Request::builder().method(Method::POST);
 543            let request_builder =
 544                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 545                    request_builder.uri(predict_edits_url)
 546                } else {
 547                    request_builder.uri(
 548                        http_client
 549                            .build_zed_llm_url("/predict_edits/v3", &[])?
 550                            .as_ref(),
 551                    )
 552                };
 553            let request = request_builder
 554                .header("Content-Type", "application/json")
 555                .header("Authorization", format!("Bearer {}", token))
 556                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 557                .body(serde_json::to_string(&request)?.into())?;
 558
 559            let mut response = http_client.send(request).await?;
 560
 561            if let Some(minimum_required_version) = response
 562                .headers()
 563                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 564                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 565            {
 566                anyhow::ensure!(
 567                    app_version >= minimum_required_version,
 568                    ZedUpdateRequiredError {
 569                        minimum_version: minimum_required_version
 570                    }
 571                );
 572            }
 573
 574            if response.status().is_success() {
 575                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 576
 577                let mut body = Vec::new();
 578                response.body_mut().read_to_end(&mut body).await?;
 579                return Ok((serde_json::from_slice(&body)?, usage));
 580            } else if !did_retry
 581                && response
 582                    .headers()
 583                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 584                    .is_some()
 585            {
 586                did_retry = true;
 587                token = llm_token.refresh(&client).await?;
 588            } else {
 589                let mut body = String::new();
 590                response.body_mut().read_to_string(&mut body).await?;
 591                anyhow::bail!(
 592                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 593                    response.status(),
 594                    body
 595                );
 596            }
 597        }
 598    }
 599
 600    fn gather_nearby_diagnostics(
 601        cursor_offset: usize,
 602        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 603        snapshot: &BufferSnapshot,
 604        max_diagnostics_bytes: usize,
 605    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 606        // TODO: Could make this more efficient
 607        let mut diagnostic_groups = Vec::new();
 608        for (language_server_id, diagnostics) in diagnostic_sets {
 609            let mut groups = Vec::new();
 610            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 611            diagnostic_groups.extend(
 612                groups
 613                    .into_iter()
 614                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 615            );
 616        }
 617
 618        // sort by proximity to cursor
 619        diagnostic_groups.sort_by_key(|group| {
 620            let range = &group.entries[group.primary_ix].range;
 621            if range.start >= cursor_offset {
 622                range.start - cursor_offset
 623            } else if cursor_offset >= range.end {
 624                cursor_offset - range.end
 625            } else {
 626                (cursor_offset - range.start).min(range.end - cursor_offset)
 627            }
 628        });
 629
 630        let mut results = Vec::new();
 631        let mut diagnostic_groups_truncated = false;
 632        let mut diagnostics_byte_count = 0;
 633        for group in diagnostic_groups {
 634            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 635            diagnostics_byte_count += raw_value.get().len();
 636            if diagnostics_byte_count > max_diagnostics_bytes {
 637                diagnostic_groups_truncated = true;
 638                break;
 639            }
 640            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 641        }
 642
 643        (results, diagnostic_groups_truncated)
 644    }
 645
 646    // TODO: Dedupe with similar code in request_prediction?
 647    pub fn cloud_request_for_zeta_cli(
 648        &mut self,
 649        project: &Entity<Project>,
 650        buffer: &Entity<Buffer>,
 651        position: language::Anchor,
 652        cx: &mut Context<Self>,
 653    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 654        let project_state = self.projects.get(&project.entity_id());
 655
 656        let index_state = project_state.map(|state| {
 657            state
 658                .syntax_index
 659                .read_with(cx, |index, _cx| index.state().clone())
 660        });
 661        let options = self.options.clone();
 662        let snapshot = buffer.read(cx).snapshot();
 663        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 664            return Task::ready(Err(anyhow!("No file path for excerpt")));
 665        };
 666        let worktree_snapshots = project
 667            .read(cx)
 668            .worktrees(cx)
 669            .map(|worktree| worktree.read(cx).snapshot())
 670            .collect::<Vec<_>>();
 671
 672        cx.background_spawn(async move {
 673            let index_state = if let Some(index_state) = index_state {
 674                Some(index_state.lock_owned().await)
 675            } else {
 676                None
 677            };
 678
 679            let cursor_point = position.to_point(&snapshot);
 680
 681            let debug_info = true;
 682            EditPredictionContext::gather_context(
 683                cursor_point,
 684                &snapshot,
 685                &options.excerpt,
 686                index_state.as_deref(),
 687            )
 688            .context("Failed to select excerpt")
 689            .map(|context| {
 690                make_cloud_request(
 691                    excerpt_path.into(),
 692                    context,
 693                    // TODO pass everything
 694                    Vec::new(),
 695                    false,
 696                    Vec::new(),
 697                    false,
 698                    None,
 699                    debug_info,
 700                    &worktree_snapshots,
 701                    index_state.as_deref(),
 702                    Some(options.max_prompt_bytes),
 703                    options.prompt_format,
 704                )
 705            })
 706        })
 707    }
 708}
 709
/// Error returned by `Zeta::perform_request` when the server's response
/// advertises a minimum required version newer than the running app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version reported by the server.
    minimum_version: SemanticVersion,
}
 717
 718fn make_cloud_request(
 719    excerpt_path: Arc<Path>,
 720    context: EditPredictionContext,
 721    events: Vec<predict_edits_v3::Event>,
 722    can_collect_data: bool,
 723    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 724    diagnostic_groups_truncated: bool,
 725    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 726    debug_info: bool,
 727    worktrees: &Vec<worktree::Snapshot>,
 728    index_state: Option<&SyntaxIndexState>,
 729    prompt_max_bytes: Option<usize>,
 730    prompt_format: PromptFormat,
 731) -> predict_edits_v3::PredictEditsRequest {
 732    let mut signatures = Vec::new();
 733    let mut declaration_to_signature_index = HashMap::default();
 734    let mut referenced_declarations = Vec::new();
 735
 736    for snippet in context.snippets {
 737        let project_entry_id = snippet.declaration.project_entry_id();
 738        let Some(path) = worktrees.iter().find_map(|worktree| {
 739            worktree.entry_for_id(project_entry_id).map(|entry| {
 740                let mut full_path = RelPathBuf::new();
 741                full_path.push(worktree.root_name());
 742                full_path.push(&entry.path);
 743                full_path
 744            })
 745        }) else {
 746            continue;
 747        };
 748
 749        let parent_index = index_state.and_then(|index_state| {
 750            snippet.declaration.parent().and_then(|parent| {
 751                add_signature(
 752                    parent,
 753                    &mut declaration_to_signature_index,
 754                    &mut signatures,
 755                    index_state,
 756                )
 757            })
 758        });
 759
 760        let (text, text_is_truncated) = snippet.declaration.item_text();
 761        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 762            path: path.as_std_path().into(),
 763            text: text.into(),
 764            range: snippet.declaration.item_range(),
 765            text_is_truncated,
 766            signature_range: snippet.declaration.signature_range_in_item_text(),
 767            parent_index,
 768            score_components: snippet.score_components,
 769            signature_score: snippet.scores.signature,
 770            declaration_score: snippet.scores.declaration,
 771        });
 772    }
 773
 774    let excerpt_parent = index_state.and_then(|index_state| {
 775        context
 776            .excerpt
 777            .parent_declarations
 778            .last()
 779            .and_then(|(parent, _)| {
 780                add_signature(
 781                    *parent,
 782                    &mut declaration_to_signature_index,
 783                    &mut signatures,
 784                    index_state,
 785                )
 786            })
 787    });
 788
 789    predict_edits_v3::PredictEditsRequest {
 790        excerpt_path,
 791        excerpt: context.excerpt_text.body,
 792        excerpt_range: context.excerpt.range,
 793        cursor_offset: context.cursor_offset_in_excerpt,
 794        referenced_declarations,
 795        signatures,
 796        excerpt_parent,
 797        events,
 798        can_collect_data,
 799        diagnostic_groups,
 800        diagnostic_groups_truncated,
 801        git_info,
 802        debug_info,
 803        prompt_max_bytes,
 804        prompt_format,
 805    }
 806}
 807
 808fn add_signature(
 809    declaration_id: DeclarationId,
 810    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 811    signatures: &mut Vec<Signature>,
 812    index: &SyntaxIndexState,
 813) -> Option<usize> {
 814    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 815        return Some(*signature_index);
 816    }
 817    let Some(parent_declaration) = index.declaration(declaration_id) else {
 818        log::error!("bug: missing parent declaration");
 819        return None;
 820    };
 821    let parent_index = parent_declaration.parent().and_then(|parent| {
 822        add_signature(parent, declaration_to_signature_index, signatures, index)
 823    });
 824    let (text, text_is_truncated) = parent_declaration.signature_text();
 825    let signature_index = signatures.len();
 826    signatures.push(Signature {
 827        text: text.into(),
 828        text_is_truncated,
 829        parent_index,
 830        range: parent_declaration.signature_range(),
 831    });
 832    declaration_to_signature_index.insert(declaration_id, signature_index);
 833    Some(signature_index)
 834}
 835
 836#[cfg(test)]
 837mod tests {
 838    use std::{
 839        path::{Path, PathBuf},
 840        sync::Arc,
 841    };
 842
 843    use client::UserStore;
 844    use clock::FakeSystemClock;
 845    use cloud_llm_client::predict_edits_v3;
 846    use futures::{
 847        AsyncReadExt, StreamExt,
 848        channel::{mpsc, oneshot},
 849    };
 850    use gpui::{
 851        Entity, TestAppContext,
 852        http_client::{FakeHttpClient, Response},
 853        prelude::*,
 854    };
 855    use indoc::indoc;
 856    use language::{LanguageServerId, OffsetRangeExt as _};
 857    use project::{FakeFs, Project};
 858    use serde_json::json;
 859    use settings::SettingsStore;
 860    use util::path;
 861    use uuid::Uuid;
 862
 863    use crate::Zeta;
 864
 865    #[gpui::test]
 866    async fn test_simple_request(cx: &mut TestAppContext) {
 867        let (zeta, mut req_rx) = init_test(cx);
 868        let fs = FakeFs::new(cx.executor());
 869        fs.insert_tree(
 870            "/root",
 871            json!({
 872                "foo.md":  "Hello!\nHow\nBye"
 873            }),
 874        )
 875        .await;
 876        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 877
 878        let buffer = project
 879            .update(cx, |project, cx| {
 880                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 881                project.open_buffer(path, cx)
 882            })
 883            .await
 884            .unwrap();
 885        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 886        let position = snapshot.anchor_before(language::Point::new(1, 3));
 887
 888        let prediction_task = zeta.update(cx, |zeta, cx| {
 889            zeta.request_prediction(&project, &buffer, position, cx)
 890        });
 891
 892        let (request, respond_tx) = req_rx.next().await.unwrap();
 893        assert_eq!(
 894            request.excerpt_path.as_ref(),
 895            Path::new(path!("root/foo.md"))
 896        );
 897        assert_eq!(request.cursor_offset, 10);
 898
 899        respond_tx
 900            .send(predict_edits_v3::PredictEditsResponse {
 901                request_id: Uuid::new_v4(),
 902                edits: vec![predict_edits_v3::Edit {
 903                    path: Path::new(path!("root/foo.md")).into(),
 904                    range: 0..snapshot.len(),
 905                    content: "Hello!\nHow are you?\nBye".into(),
 906                }],
 907                debug_info: None,
 908            })
 909            .unwrap();
 910
 911        let prediction = prediction_task.await.unwrap().unwrap();
 912
 913        assert_eq!(prediction.edits.len(), 1);
 914        assert_eq!(
 915            prediction.edits[0].0.to_point(&snapshot).start,
 916            language::Point::new(1, 3)
 917        );
 918        assert_eq!(prediction.edits[0].1, " are you?");
 919    }
 920
 921    #[gpui::test]
 922    async fn test_request_events(cx: &mut TestAppContext) {
 923        let (zeta, mut req_rx) = init_test(cx);
 924        let fs = FakeFs::new(cx.executor());
 925        fs.insert_tree(
 926            "/root",
 927            json!({
 928                "foo.md": "Hello!\n\nBye"
 929            }),
 930        )
 931        .await;
 932        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 933
 934        let buffer = project
 935            .update(cx, |project, cx| {
 936                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
 937                project.open_buffer(path, cx)
 938            })
 939            .await
 940            .unwrap();
 941
 942        zeta.update(cx, |zeta, cx| {
 943            zeta.register_buffer(&buffer, &project, cx);
 944        });
 945
 946        buffer.update(cx, |buffer, cx| {
 947            buffer.edit(vec![(7..7, "How")], None, cx);
 948        });
 949
 950        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 951        let position = snapshot.anchor_before(language::Point::new(1, 3));
 952
 953        let prediction_task = zeta.update(cx, |zeta, cx| {
 954            zeta.request_prediction(&project, &buffer, position, cx)
 955        });
 956
 957        let (request, respond_tx) = req_rx.next().await.unwrap();
 958
 959        assert_eq!(request.events.len(), 1);
 960        assert_eq!(
 961            request.events[0],
 962            predict_edits_v3::Event::BufferChange {
 963                path: Some(PathBuf::from(path!("root/foo.md"))),
 964                old_path: None,
 965                diff: indoc! {"
 966                        @@ -1,3 +1,3 @@
 967                         Hello!
 968                        -
 969                        +How
 970                         Bye
 971                    "}
 972                .to_string(),
 973                predicted: false
 974            }
 975        );
 976
 977        respond_tx
 978            .send(predict_edits_v3::PredictEditsResponse {
 979                request_id: Uuid::new_v4(),
 980                edits: vec![predict_edits_v3::Edit {
 981                    path: Path::new(path!("root/foo.md")).into(),
 982                    range: 0..snapshot.len(),
 983                    content: "Hello!\nHow are you?\nBye".into(),
 984                }],
 985                debug_info: None,
 986            })
 987            .unwrap();
 988
 989        let prediction = prediction_task.await.unwrap().unwrap();
 990
 991        assert_eq!(prediction.edits.len(), 1);
 992        assert_eq!(
 993            prediction.edits[0].0.to_point(&snapshot).start,
 994            language::Point::new(1, 3)
 995        );
 996        assert_eq!(prediction.edits[0].1, " are you?");
 997    }
 998
 999    #[gpui::test]
1000    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1001        let (zeta, mut req_rx) = init_test(cx);
1002        let fs = FakeFs::new(cx.executor());
1003        fs.insert_tree(
1004            "/root",
1005            json!({
1006                "foo.md": "Hello!\nBye"
1007            }),
1008        )
1009        .await;
1010        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1011
1012        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1013        let diagnostic = lsp::Diagnostic {
1014            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1015            severity: Some(lsp::DiagnosticSeverity::ERROR),
1016            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1017            ..Default::default()
1018        };
1019
1020        project.update(cx, |project, cx| {
1021            project.lsp_store().update(cx, |lsp_store, cx| {
1022                // Create some diagnostics
1023                lsp_store
1024                    .update_diagnostics(
1025                        LanguageServerId(0),
1026                        lsp::PublishDiagnosticsParams {
1027                            uri: path_to_buffer_uri.clone(),
1028                            diagnostics: vec![diagnostic],
1029                            version: None,
1030                        },
1031                        None,
1032                        language::DiagnosticSourceKind::Pushed,
1033                        &[],
1034                        cx,
1035                    )
1036                    .unwrap();
1037            });
1038        });
1039
1040        let buffer = project
1041            .update(cx, |project, cx| {
1042                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1043                project.open_buffer(path, cx)
1044            })
1045            .await
1046            .unwrap();
1047
1048        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1049        let position = snapshot.anchor_before(language::Point::new(0, 0));
1050
1051        let _prediction_task = zeta.update(cx, |zeta, cx| {
1052            zeta.request_prediction(&project, &buffer, position, cx)
1053        });
1054
1055        let (request, _respond_tx) = req_rx.next().await.unwrap();
1056
1057        assert_eq!(request.diagnostic_groups.len(), 1);
1058        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1059            .unwrap();
1060        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1061        assert_eq!(
1062            value,
1063            json!({
1064                "entries": [{
1065                    "range": {
1066                        "start": 8,
1067                        "end": 10
1068                    },
1069                    "diagnostic": {
1070                        "source": null,
1071                        "code": null,
1072                        "code_description": null,
1073                        "severity": 1,
1074                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1075                        "markdown": null,
1076                        "group_id": 0,
1077                        "is_primary": true,
1078                        "is_disk_based": false,
1079                        "is_unnecessary": false,
1080                        "source_kind": "Pushed",
1081                        "data": null,
1082                        "underline": true
1083                    }
1084                }],
1085                "primary_ix": 0
1086            })
1087        );
1088    }
1089
1090    fn init_test(
1091        cx: &mut TestAppContext,
1092    ) -> (
1093        Entity<Zeta>,
1094        mpsc::UnboundedReceiver<(
1095            predict_edits_v3::PredictEditsRequest,
1096            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1097        )>,
1098    ) {
1099        cx.update(move |cx| {
1100            let settings_store = SettingsStore::test(cx);
1101            cx.set_global(settings_store);
1102            language::init(cx);
1103            Project::init_settings(cx);
1104
1105            let (req_tx, req_rx) = mpsc::unbounded();
1106
1107            let http_client = FakeHttpClient::create({
1108                move |req| {
1109                    let uri = req.uri().path().to_string();
1110                    let mut body = req.into_body();
1111                    let req_tx = req_tx.clone();
1112                    async move {
1113                        let resp = match uri.as_str() {
1114                            "/client/llm_tokens" => serde_json::to_string(&json!({
1115                                "token": "test"
1116                            }))
1117                            .unwrap(),
1118                            "/predict_edits/v3" => {
1119                                let mut buf = Vec::new();
1120                                body.read_to_end(&mut buf).await.ok();
1121                                let req = serde_json::from_slice(&buf).unwrap();
1122
1123                                let (res_tx, res_rx) = oneshot::channel();
1124                                req_tx.unbounded_send((req, res_tx)).unwrap();
1125                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1126                            }
1127                            _ => {
1128                                panic!("Unexpected path: {}", uri)
1129                            }
1130                        };
1131
1132                        Ok(Response::builder().body(resp.into()).unwrap())
1133                    }
1134                }
1135            });
1136
1137            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1138            client.cloud_client().set_credentials(1, "test".into());
1139
1140            language_model::init(client.clone(), cx);
1141
1142            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1143            let zeta = Zeta::global(&client, &user_store, cx);
1144            (zeta, req_rx)
1145        })
1146    }
1147}