zeta.rs

   1mod rate_completion_modal;
   2
   3pub use rate_completion_modal::*;
   4
   5use anyhow::{anyhow, Context as _, Result};
   6use arrayvec::ArrayVec;
   7use client::Client;
   8use collections::{HashMap, HashSet, VecDeque};
   9use futures::AsyncReadExt;
  10use gpui::{
  11    actions, AppContext, AsyncAppContext, Context, EntityId, Global, Model, ModelContext,
  12    Subscription, Task,
  13};
  14use http_client::{HttpClient, Method};
  15use language::{
  16    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
  17    Point, ToOffset, ToPoint,
  18};
  19use language_models::LlmApiToken;
  20use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
  21use std::{
  22    borrow::Cow,
  23    cmp,
  24    fmt::Write,
  25    future::Future,
  26    mem,
  27    ops::Range,
  28    path::Path,
  29    sync::Arc,
  30    time::{Duration, Instant},
  31};
  32use telemetry_events::InlineCompletionRating;
  33use uuid::Uuid;
  34
  35const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  36const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  37const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  38const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  39const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  40
// Declares the `ClearHistory` action in the `zeta` action namespace. The handler is not
// registered in this section; presumably it calls `Zeta::clear_history` — TODO confirm.
actions!(zeta, [ClearHistory]);
  42
  43#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  44pub struct InlineCompletionId(Uuid);
  45
  46impl From<InlineCompletionId> for gpui::ElementId {
  47    fn from(value: InlineCompletionId) -> Self {
  48        gpui::ElementId::Uuid(value.0)
  49    }
  50}
  51
  52impl std::fmt::Display for InlineCompletionId {
  53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  54        write!(f, "{}", self.0)
  55    }
  56}
  57
  58impl InlineCompletionId {
  59    fn new() -> Self {
  60        Self(Uuid::new_v4())
  61    }
  62}
  63
// Global slot holding the app-wide `Zeta` model; installed by `Zeta::register` and read
// back by `Zeta::global`.
#[derive(Clone)]
struct ZetaGlobal(Model<Zeta>);

impl Global for ZetaGlobal {}
  68
/// One edit prediction returned by the model, together with the prompt that produced it.
#[derive(Clone)]
pub struct InlineCompletion {
    /// Unique id, used for shown/rated bookkeeping.
    id: InlineCompletionId,
    /// Buffer path at request time (`untitled` when the buffer had no file).
    path: Arc<Path>,
    /// Byte-offset range of the buffer excerpt that was sent as the editable region.
    excerpt_range: Range<usize>,
    /// Predicted edits, anchored in `snapshot`.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the prediction was computed against.
    snapshot: BufferSnapshot,
    /// Rendered outline of the buffer (kept for rating telemetry; not sent in the request).
    input_outline: Arc<str>,
    /// Rendered edit-history section of the prompt.
    input_events: Arc<str>,
    /// Rendered excerpt section of the prompt.
    input_excerpt: Arc<str>,
    /// Raw excerpt text returned by the model.
    output_excerpt: Arc<str>,
    /// When the request task started.
    request_sent_at: Instant,
    /// When the response had been parsed into this completion.
    response_received_at: Instant,
}
  83
impl InlineCompletion {
    /// Time from the start of the request task until the response had been parsed into
    /// this completion.
    fn latency(&self) -> Duration {
        self.response_received_at
            .duration_since(self.request_sent_at)
    }

    /// Re-targets this completion's edits onto `new_snapshot`, a later version of the
    /// buffer the completion was computed against (`self.snapshot`).
    ///
    /// The user's edits since `self.snapshot` are walked in order alongside the model's
    /// edits:
    /// - a model edit that no remaining user edit overlaps or touches is kept unchanged
    ///   (its anchors still resolve in the new snapshot);
    /// - a model edit whose old range the user replaced *exactly* survives only if the
    ///   user's new text is a prefix of the model's text; the not-yet-typed suffix becomes
    ///   an insertion at the end of the user's edit;
    /// - any other user edit that overlaps or touches a model edit invalidates the whole
    ///   completion.
    ///
    /// Returns `None` when the completion was invalidated, or when no edits remain (for
    /// example because the user already typed everything that was predicted).
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
        let mut edits = Vec::new();

        let mut user_edits = new_snapshot
            .edits_since::<usize>(&self.snapshot.version)
            .peekable();
        for (model_old_range, model_new_text) in self.edits.iter() {
            let model_offset_range = model_old_range.to_offset(&self.snapshot);
            // Skip user edits that end strictly before this model edit begins; they
            // cannot affect it.
            while let Some(next_user_edit) = user_edits.peek() {
                if next_user_edit.old.end < model_offset_range.start {
                    user_edits.next();
                } else {
                    break;
                }
            }

            if let Some(user_edit) = user_edits.peek() {
                if user_edit.old.start > model_offset_range.end {
                    // The next user edit starts after this model edit: keep it as-is.
                    // The user edit stays peeked so later model edits can check it.
                    edits.push((model_old_range.clone(), model_new_text.clone()));
                } else if user_edit.old == model_offset_range {
                    // The user replaced exactly the range the model wanted to replace.
                    let user_new_text = new_snapshot
                        .text_for_range(user_edit.new.clone())
                        .collect::<String>();

                    if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                        // The user is typing the prediction: offer only the remainder,
                        // as an empty-range insertion at the end of the user's text.
                        if !model_suffix.is_empty() {
                            edits.push((
                                new_snapshot.anchor_after(user_edit.new.end)
                                    ..new_snapshot.anchor_before(user_edit.new.end),
                                model_suffix.into(),
                            ));
                        }

                        user_edits.next();
                    } else {
                        // The user typed something different from the prediction.
                        return None;
                    }
                } else {
                    // Partial overlap, or the user edit touches a boundary of the model edit.
                    return None;
                }
            } else {
                // No user edits remain: this model edit is unaffected.
                edits.push((model_old_range.clone(), model_new_text.clone()));
            }
        }

        if edits.is_empty() {
            None
        } else {
            Some(edits)
        }
    }
}
 142
 143impl std::fmt::Debug for InlineCompletion {
 144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 145        f.debug_struct("InlineCompletion")
 146            .field("id", &self.id)
 147            .field("path", &self.path)
 148            .field("edits", &self.edits)
 149            .finish_non_exhaustive()
 150    }
 151}
 152
/// State of the Zeta edit-prediction model: recent buffer-change history, tracked buffers,
/// recent completions with their shown/rated bookkeeping, and the LLM API token.
pub struct Zeta {
    // Used for HTTP requests, token acquisition/refresh and telemetry.
    client: Arc<Client>,
    // Bounded buffer-change history (oldest at the front) rendered into prompts.
    events: VecDeque<Event>,
    // Buffers being tracked for changes, keyed by entity id, with their last seen snapshot.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    // Most recent completions, newest at the front; capped in `request_completion_impl`.
    recent_completions: VecDeque<InlineCompletion>,
    // Ids of completions the user has rated.
    rated_completions: HashSet<InlineCompletionId>,
    // Ids of completions that have been shown to the user.
    shown_completions: HashSet<InlineCompletionId>,
    // Bearer token for prediction requests.
    llm_token: LlmApiToken,
    // Keeps the token-refresh subscription alive for as long as this model lives.
    _llm_token_subscription: Subscription,
}
 163
 164impl Zeta {
 165    pub fn global(cx: &mut AppContext) -> Option<Model<Self>> {
 166        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 167    }
 168
 169    pub fn register(client: Arc<Client>, cx: &mut AppContext) -> Model<Self> {
 170        Self::global(cx).unwrap_or_else(|| {
 171            let model = cx.new_model(|cx| Self::new(client, cx));
 172            cx.set_global(ZetaGlobal(model.clone()));
 173            model
 174        })
 175    }
 176
 177    pub fn clear_history(&mut self) {
 178        self.events.clear();
 179    }
 180
 181    fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
 182        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
 183
 184        Self {
 185            client,
 186            events: VecDeque::new(),
 187            recent_completions: VecDeque::new(),
 188            rated_completions: HashSet::default(),
 189            shown_completions: HashSet::default(),
 190            registered_buffers: HashMap::default(),
 191            llm_token: LlmApiToken::default(),
 192            _llm_token_subscription: cx.subscribe(
 193                &refresh_llm_token_listener,
 194                |this, _listener, _event, cx| {
 195                    let client = this.client.clone();
 196                    let llm_token = this.llm_token.clone();
 197                    cx.spawn(|_this, _cx| async move {
 198                        llm_token.refresh(&client).await?;
 199                        anyhow::Ok(())
 200                    })
 201                    .detach_and_log_err(cx);
 202                },
 203            ),
 204        }
 205    }
 206
 207    fn push_event(&mut self, event: Event) {
 208        const MAX_EVENT_COUNT: usize = 20;
 209
 210        if let Some(Event::BufferChange {
 211            new_snapshot: last_new_snapshot,
 212            timestamp: last_timestamp,
 213            ..
 214        }) = self.events.back_mut()
 215        {
 216            // Coalesce edits for the same buffer when they happen one after the other.
 217            let Event::BufferChange {
 218                old_snapshot,
 219                new_snapshot,
 220                timestamp,
 221            } = &event;
 222
 223            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 224                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 225                && old_snapshot.version == last_new_snapshot.version
 226            {
 227                *last_new_snapshot = new_snapshot.clone();
 228                *last_timestamp = *timestamp;
 229                return;
 230            }
 231        }
 232
 233        self.events.push_back(event);
 234        if self.events.len() > MAX_EVENT_COUNT {
 235            self.events.pop_front();
 236        }
 237    }
 238
 239    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
 240        let buffer_id = buffer.entity_id();
 241        let weak_buffer = buffer.downgrade();
 242
 243        if let std::collections::hash_map::Entry::Vacant(entry) =
 244            self.registered_buffers.entry(buffer_id)
 245        {
 246            let snapshot = buffer.read(cx).snapshot();
 247
 248            entry.insert(RegisteredBuffer {
 249                snapshot,
 250                _subscriptions: [
 251                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 252                        this.handle_buffer_event(buffer, event, cx);
 253                    }),
 254                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 255                        this.registered_buffers.remove(&weak_buffer.entity_id());
 256                    }),
 257                ],
 258            });
 259        };
 260    }
 261
 262    fn handle_buffer_event(
 263        &mut self,
 264        buffer: Model<Buffer>,
 265        event: &language::BufferEvent,
 266        cx: &mut ModelContext<Self>,
 267    ) {
 268        match event {
 269            language::BufferEvent::Edited => {
 270                self.report_changes_for_buffer(&buffer, cx);
 271            }
 272            _ => {}
 273        }
 274    }
 275
 276    pub fn request_completion_impl<F, R>(
 277        &mut self,
 278        buffer: &Model<Buffer>,
 279        position: language::Anchor,
 280        cx: &mut ModelContext<Self>,
 281        perform_predict_edits: F,
 282    ) -> Task<Result<InlineCompletion>>
 283    where
 284        F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
 285        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 286    {
 287        let snapshot = self.report_changes_for_buffer(buffer, cx);
 288        let point = position.to_point(&snapshot);
 289        let offset = point.to_offset(&snapshot);
 290        let excerpt_range = excerpt_range_for_position(point, &snapshot);
 291        let events = self.events.clone();
 292        let path = snapshot
 293            .file()
 294            .map(|f| f.path().clone())
 295            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 296
 297        let client = self.client.clone();
 298        let llm_token = self.llm_token.clone();
 299
 300        cx.spawn(|this, mut cx| async move {
 301            let request_sent_at = Instant::now();
 302
 303            let input_events = cx
 304                .background_executor()
 305                .spawn(async move {
 306                    let mut input_events = String::new();
 307                    for event in events {
 308                        if !input_events.is_empty() {
 309                            input_events.push('\n');
 310                            input_events.push('\n');
 311                        }
 312                        input_events.push_str(&event.to_prompt());
 313                    }
 314                    input_events
 315                })
 316                .await;
 317
 318            let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
 319            let input_outline = prompt_for_outline(&snapshot);
 320
 321            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 322
 323            let body = PredictEditsParams {
 324                input_events: input_events.clone(),
 325                input_excerpt: input_excerpt.clone(),
 326            };
 327
 328            let response = perform_predict_edits(client, llm_token, body).await?;
 329
 330            let output_excerpt = response.output_excerpt;
 331            log::debug!("completion response: {}", output_excerpt);
 332
 333            let inline_completion = Self::process_completion_response(
 334                output_excerpt,
 335                &snapshot,
 336                excerpt_range,
 337                path,
 338                input_outline,
 339                input_events,
 340                input_excerpt,
 341                request_sent_at,
 342                &cx,
 343            )
 344            .await?;
 345
 346            this.update(&mut cx, |this, cx| {
 347                this.recent_completions
 348                    .push_front(inline_completion.clone());
 349                if this.recent_completions.len() > 50 {
 350                    let completion = this.recent_completions.pop_back().unwrap();
 351                    this.shown_completions.remove(&completion.id);
 352                    this.rated_completions.remove(&completion.id);
 353                }
 354                cx.notify();
 355            })?;
 356
 357            Ok(inline_completion)
 358        })
 359    }
 360
 361    // Generates several example completions of various states to fill the Zeta completion modal
 362    #[cfg(any(test, feature = "test-support"))]
 363    pub fn fill_with_fake_completions(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 364        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 365            And maybe a short line
 366
 367            Then a few lines
 368
 369            and then another
 370            "#};
 371
 372        let buffer = cx.new_model(|cx| Buffer::local(test_buffer_text, cx));
 373        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 374
 375        let completion_tasks = vec![
 376            self.fake_completion(
 377                &buffer,
 378                position,
 379                PredictEditsResponse {
 380                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 381a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 382[here's an edit]
 383And maybe a short line
 384Then a few lines
 385and then another
 386{EDITABLE_REGION_END_MARKER}
 387                        ", ),
 388                },
 389                cx,
 390            ),
 391            self.fake_completion(
 392                &buffer,
 393                position,
 394                PredictEditsResponse {
 395                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 396a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 397And maybe a short line
 398[and another edit]
 399Then a few lines
 400and then another
 401{EDITABLE_REGION_END_MARKER}
 402                        "#),
 403                },
 404                cx,
 405            ),
 406            self.fake_completion(
 407                &buffer,
 408                position,
 409                PredictEditsResponse {
 410                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 411a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 412And maybe a short line
 413
 414Then a few lines
 415
 416and then another
 417{EDITABLE_REGION_END_MARKER}
 418                        "#),
 419                },
 420                cx,
 421            ),
 422            self.fake_completion(
 423                &buffer,
 424                position,
 425                PredictEditsResponse {
 426                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 427a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 428And maybe a short line
 429
 430Then a few lines
 431
 432and then another
 433{EDITABLE_REGION_END_MARKER}
 434                        "#),
 435                },
 436                cx,
 437            ),
 438            self.fake_completion(
 439                &buffer,
 440                position,
 441                PredictEditsResponse {
 442                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 443a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 444And maybe a short line
 445Then a few lines
 446[a third completion]
 447and then another
 448{EDITABLE_REGION_END_MARKER}
 449                        "#),
 450                },
 451                cx,
 452            ),
 453            self.fake_completion(
 454                &buffer,
 455                position,
 456                PredictEditsResponse {
 457                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 458a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 459And maybe a short line
 460and then another
 461[fourth completion example]
 462{EDITABLE_REGION_END_MARKER}
 463                        "#),
 464                },
 465                cx,
 466            ),
 467            self.fake_completion(
 468                &buffer,
 469                position,
 470                PredictEditsResponse {
 471                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 472a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 473And maybe a short line
 474Then a few lines
 475and then another
 476[fifth and final completion]
 477{EDITABLE_REGION_END_MARKER}
 478                        "#),
 479                },
 480                cx,
 481            ),
 482        ];
 483
 484        cx.spawn(|zeta, mut cx| async move {
 485            for task in completion_tasks {
 486                task.await.unwrap();
 487            }
 488
 489            zeta.update(&mut cx, |zeta, _cx| {
 490                zeta.recent_completions.get_mut(2).unwrap().edits = Arc::new([]);
 491                zeta.recent_completions.get_mut(3).unwrap().edits = Arc::new([]);
 492            })
 493            .ok();
 494        })
 495    }
 496
 497    #[cfg(any(test, feature = "test-support"))]
 498    pub fn fake_completion(
 499        &mut self,
 500        buffer: &Model<Buffer>,
 501        position: language::Anchor,
 502        response: PredictEditsResponse,
 503        cx: &mut ModelContext<Self>,
 504    ) -> Task<Result<InlineCompletion>> {
 505        use std::future::ready;
 506
 507        self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response)))
 508    }
 509
 510    pub fn request_completion(
 511        &mut self,
 512        buffer: &Model<Buffer>,
 513        position: language::Anchor,
 514        cx: &mut ModelContext<Self>,
 515    ) -> Task<Result<InlineCompletion>> {
 516        self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
 517    }
 518
 519    fn perform_predict_edits(
 520        client: Arc<Client>,
 521        llm_token: LlmApiToken,
 522        body: PredictEditsParams,
 523    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 524        async move {
 525            let http_client = client.http_client();
 526            let mut token = llm_token.acquire(&client).await?;
 527            let mut did_retry = false;
 528
 529            loop {
 530                let request_builder = http_client::Request::builder();
 531                let request = request_builder
 532                    .method(Method::POST)
 533                    .uri(
 534                        http_client
 535                            .build_zed_llm_url("/predict_edits", &[])?
 536                            .as_ref(),
 537                    )
 538                    .header("Content-Type", "application/json")
 539                    .header("Authorization", format!("Bearer {}", token))
 540                    .body(serde_json::to_string(&body)?.into())?;
 541
 542                let mut response = http_client.send(request).await?;
 543
 544                if response.status().is_success() {
 545                    let mut body = String::new();
 546                    response.body_mut().read_to_string(&mut body).await?;
 547                    return Ok(serde_json::from_str(&body)?);
 548                } else if !did_retry
 549                    && response
 550                        .headers()
 551                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 552                        .is_some()
 553                {
 554                    did_retry = true;
 555                    token = llm_token.refresh(&client).await?;
 556                } else {
 557                    let mut body = String::new();
 558                    response.body_mut().read_to_string(&mut body).await?;
 559                    return Err(anyhow!(
 560                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 561                        response.status(),
 562                        body
 563                    ));
 564                }
 565            }
 566        }
 567    }
 568
 569    #[allow(clippy::too_many_arguments)]
 570    fn process_completion_response(
 571        output_excerpt: String,
 572        snapshot: &BufferSnapshot,
 573        excerpt_range: Range<usize>,
 574        path: Arc<Path>,
 575        input_outline: String,
 576        input_events: String,
 577        input_excerpt: String,
 578        request_sent_at: Instant,
 579        cx: &AsyncAppContext,
 580    ) -> Task<Result<InlineCompletion>> {
 581        let snapshot = snapshot.clone();
 582        cx.background_executor().spawn(async move {
 583            let content = output_excerpt.replace(CURSOR_MARKER, "");
 584
 585            let start_markers = content
 586                .match_indices(EDITABLE_REGION_START_MARKER)
 587                .collect::<Vec<_>>();
 588            anyhow::ensure!(
 589                start_markers.len() == 1,
 590                "expected exactly one start marker, found {}",
 591                start_markers.len()
 592            );
 593
 594            let end_markers = content
 595                .match_indices(EDITABLE_REGION_END_MARKER)
 596                .collect::<Vec<_>>();
 597            anyhow::ensure!(
 598                end_markers.len() == 1,
 599                "expected exactly one end marker, found {}",
 600                end_markers.len()
 601            );
 602
 603            let sof_markers = content
 604                .match_indices(START_OF_FILE_MARKER)
 605                .collect::<Vec<_>>();
 606            anyhow::ensure!(
 607                sof_markers.len() <= 1,
 608                "expected at most one start-of-file marker, found {}",
 609                sof_markers.len()
 610            );
 611
 612            let codefence_start = start_markers[0].0;
 613            let content = &content[codefence_start..];
 614
 615            let newline_ix = content.find('\n').context("could not find newline")?;
 616            let content = &content[newline_ix + 1..];
 617
 618            let codefence_end = content
 619                .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 620                .context("could not find end marker")?;
 621            let new_text = &content[..codefence_end];
 622
 623            let old_text = snapshot
 624                .text_for_range(excerpt_range.clone())
 625                .collect::<String>();
 626
 627            let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
 628
 629            Ok(InlineCompletion {
 630                id: InlineCompletionId::new(),
 631                path,
 632                excerpt_range,
 633                edits: edits.into(),
 634                snapshot: snapshot.clone(),
 635                input_outline: input_outline.into(),
 636                input_events: input_events.into(),
 637                input_excerpt: input_excerpt.into(),
 638                output_excerpt: output_excerpt.into(),
 639                request_sent_at,
 640                response_received_at: Instant::now(),
 641            })
 642        })
 643    }
 644
 645    pub fn compute_edits(
 646        old_text: String,
 647        new_text: &str,
 648        offset: usize,
 649        snapshot: &BufferSnapshot,
 650    ) -> Vec<(Range<Anchor>, String)> {
 651        let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
 652
 653        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
 654        let mut old_start = offset;
 655        for change in diff.iter_all_changes() {
 656            let value = change.value();
 657            match change.tag() {
 658                similar::ChangeTag::Equal => {
 659                    old_start += value.len();
 660                }
 661                similar::ChangeTag::Delete => {
 662                    let old_end = old_start + value.len();
 663                    if let Some((last_old_range, _)) = edits.last_mut() {
 664                        if last_old_range.end == old_start {
 665                            last_old_range.end = old_end;
 666                        } else {
 667                            edits.push((old_start..old_end, String::new()));
 668                        }
 669                    } else {
 670                        edits.push((old_start..old_end, String::new()));
 671                    }
 672                    old_start = old_end;
 673                }
 674                similar::ChangeTag::Insert => {
 675                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
 676                        if last_old_range.end == old_start {
 677                            last_new_text.push_str(value);
 678                        } else {
 679                            edits.push((old_start..old_start, value.into()));
 680                        }
 681                    } else {
 682                        edits.push((old_start..old_start, value.into()));
 683                    }
 684                }
 685            }
 686        }
 687
 688        edits
 689            .into_iter()
 690            .map(|(mut old_range, new_text)| {
 691                let prefix_len = common_prefix(
 692                    snapshot.chars_for_range(old_range.clone()),
 693                    new_text.chars(),
 694                );
 695                old_range.start += prefix_len;
 696                let suffix_len = common_prefix(
 697                    snapshot.reversed_chars_for_range(old_range.clone()),
 698                    new_text[prefix_len..].chars().rev(),
 699                );
 700                old_range.end = old_range.end.saturating_sub(suffix_len);
 701
 702                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 703                (
 704                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
 705                    new_text,
 706                )
 707            })
 708            .collect()
 709    }
 710
 711    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
 712        self.rated_completions.contains(&completion_id)
 713    }
 714
 715    pub fn was_completion_shown(&self, completion_id: InlineCompletionId) -> bool {
 716        self.shown_completions.contains(&completion_id)
 717    }
 718
 719    pub fn completion_shown(&mut self, completion_id: InlineCompletionId) {
 720        self.shown_completions.insert(completion_id);
 721    }
 722
 723    pub fn rate_completion(
 724        &mut self,
 725        completion: &InlineCompletion,
 726        rating: InlineCompletionRating,
 727        feedback: String,
 728        cx: &mut ModelContext<Self>,
 729    ) {
 730        telemetry::event!(
 731            "Inline Completion Rated",
 732            rating,
 733            input_events = completion.input_events,
 734            input_excerpt = completion.input_excerpt,
 735            input_outline = completion.input_outline,
 736            output_excerpt = completion.output_excerpt,
 737            feedback
 738        );
 739        self.client.telemetry().flush_events();
 740        cx.notify();
 741    }
 742
 743    pub fn recent_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
 744        self.recent_completions.iter()
 745    }
 746
 747    pub fn recent_completions_len(&self) -> usize {
 748        self.recent_completions.len()
 749    }
 750
 751    fn report_changes_for_buffer(
 752        &mut self,
 753        buffer: &Model<Buffer>,
 754        cx: &mut ModelContext<Self>,
 755    ) -> BufferSnapshot {
 756        self.register_buffer(buffer, cx);
 757
 758        let registered_buffer = self
 759            .registered_buffers
 760            .get_mut(&buffer.entity_id())
 761            .unwrap();
 762        let new_snapshot = buffer.read(cx).snapshot();
 763
 764        if new_snapshot.version != registered_buffer.snapshot.version {
 765            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 766            self.push_event(Event::BufferChange {
 767                old_snapshot,
 768                new_snapshot: new_snapshot.clone(),
 769                timestamp: Instant::now(),
 770            });
 771        }
 772
 773        new_snapshot
 774    }
 775}
 776
 777fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 778    a.zip(b)
 779        .take_while(|(a, b)| a == b)
 780        .map(|(a, _)| a.len_utf8())
 781        .sum()
 782}
 783
 784fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
 785    let mut input_outline = String::new();
 786
 787    writeln!(
 788        input_outline,
 789        "```{}",
 790        snapshot
 791            .file()
 792            .map_or(Cow::Borrowed("untitled"), |file| file
 793                .path()
 794                .to_string_lossy())
 795    )
 796    .unwrap();
 797
 798    if let Some(outline) = snapshot.outline(None) {
 799        let guess_size = outline.items.len() * 15;
 800        input_outline.reserve(guess_size);
 801        for item in outline.items.iter() {
 802            let spacing = " ".repeat(item.depth);
 803            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
 804        }
 805    }
 806
 807    writeln!(input_outline, "```").unwrap();
 808
 809    input_outline
 810}
 811
 812fn prompt_for_excerpt(
 813    snapshot: &BufferSnapshot,
 814    excerpt_range: &Range<usize>,
 815    offset: usize,
 816) -> String {
 817    let mut prompt_excerpt = String::new();
 818    writeln!(
 819        prompt_excerpt,
 820        "```{}",
 821        snapshot
 822            .file()
 823            .map_or(Cow::Borrowed("untitled"), |file| file
 824                .path()
 825                .to_string_lossy())
 826    )
 827    .unwrap();
 828
 829    if excerpt_range.start == 0 {
 830        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
 831    }
 832
 833    let point_range = excerpt_range.to_point(snapshot);
 834    if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
 835        let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
 836        for chunk in snapshot.text_for_range(extra_context_line_range) {
 837            prompt_excerpt.push_str(chunk);
 838        }
 839    }
 840    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
 841    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
 842        prompt_excerpt.push_str(chunk);
 843    }
 844    prompt_excerpt.push_str(CURSOR_MARKER);
 845    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
 846        prompt_excerpt.push_str(chunk);
 847    }
 848    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 849
 850    if point_range.end.row < snapshot.max_point().row
 851        && !snapshot.is_line_blank(point_range.end.row + 1)
 852    {
 853        let extra_context_line_range = point_range.end
 854            ..Point::new(
 855                point_range.end.row + 1,
 856                snapshot.line_len(point_range.end.row + 1),
 857            );
 858        for chunk in snapshot.text_for_range(extra_context_line_range) {
 859            prompt_excerpt.push_str(chunk);
 860        }
 861    }
 862
 863    write!(prompt_excerpt, "\n```").unwrap();
 864    prompt_excerpt
 865}
 866
 867fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
 868    const CONTEXT_LINES: u32 = 32;
 869
 870    let mut context_lines_before = CONTEXT_LINES;
 871    let mut context_lines_after = CONTEXT_LINES;
 872    if point.row < CONTEXT_LINES {
 873        context_lines_after += CONTEXT_LINES - point.row;
 874    } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
 875        context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
 876    }
 877
 878    let excerpt_start_row = point.row.saturating_sub(context_lines_before);
 879    let excerpt_start = Point::new(excerpt_start_row, 0);
 880    let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
 881    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
 882    excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
 883}
 884
/// Per-buffer state held for a buffer registered with Zeta.
struct RegisteredBuffer {
    /// A snapshot of the buffer. NOTE(review): presumably the last snapshot observed, used
    /// as the `old_snapshot` of the next `Event::BufferChange` — confirm at the use site.
    snapshot: BufferSnapshot,
    /// Two subscriptions owned only to keep them alive (never read). NOTE(review):
    /// presumably dropping this struct unsubscribes them — confirm.
    _subscriptions: [gpui::Subscription; 2],
}
 889
/// A user action recorded for inclusion in the completion prompt (see `Event::to_prompt`).
#[derive(Clone)]
enum Event {
    /// A buffer's contents and/or path changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// The buffer as it was before the change.
        old_snapshot: BufferSnapshot,
        /// The buffer as it was after the change.
        new_snapshot: BufferSnapshot,
        /// When the change was recorded. Not used by `to_prompt`. NOTE(review): presumably
        /// compared against `BUFFER_CHANGE_GROUPING_INTERVAL` to coalesce edits — confirm.
        timestamp: Instant,
    },
}
 898
 899impl Event {
 900    fn to_prompt(&self) -> String {
 901        match self {
 902            Event::BufferChange {
 903                old_snapshot,
 904                new_snapshot,
 905                ..
 906            } => {
 907                let mut prompt = String::new();
 908
 909                let old_path = old_snapshot
 910                    .file()
 911                    .map(|f| f.path().as_ref())
 912                    .unwrap_or(Path::new("untitled"));
 913                let new_path = new_snapshot
 914                    .file()
 915                    .map(|f| f.path().as_ref())
 916                    .unwrap_or(Path::new("untitled"));
 917                if old_path != new_path {
 918                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
 919                }
 920
 921                let diff =
 922                    similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
 923                        .unified_diff()
 924                        .to_string();
 925                if !diff.is_empty() {
 926                    write!(
 927                        prompt,
 928                        "User edited {:?}:\n```diff\n{}\n```",
 929                        new_path, diff
 930                    )
 931                    .unwrap();
 932                }
 933
 934                prompt
 935            }
 936        }
 937    }
 938}
 939
/// The completion currently offered by `ZetaInlineCompletionProvider`, tagged with the
/// buffer it was requested for.
#[derive(Debug, Clone)]
struct CurrentInlineCompletion {
    /// Entity id of the buffer the completion was requested for; `suggest` discards the
    /// completion when asked about a different buffer.
    buffer_id: EntityId,
    /// The completion returned by Zeta.
    completion: InlineCompletion,
}
 945
 946impl CurrentInlineCompletion {
 947    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
 948        if self.buffer_id != old_completion.buffer_id {
 949            return true;
 950        }
 951
 952        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
 953            return true;
 954        };
 955        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
 956            return false;
 957        };
 958
 959        if old_edits.len() == 1 && new_edits.len() == 1 {
 960            let (old_range, old_text) = &old_edits[0];
 961            let (new_range, new_text) = &new_edits[0];
 962            new_range == old_range && new_text.starts_with(old_text)
 963        } else {
 964            true
 965        }
 966    }
 967}
 968
/// An in-flight completion request started by `ZetaInlineCompletionProvider::refresh`.
struct PendingCompletion {
    /// Sequence number taken from `next_pending_completion_id`; the task compares it to
    /// find its own slot when it finishes.
    id: usize,
    /// The request task, owned here and never read. NOTE(review): presumably dropping it
    /// (via `pop`/`clear`) cancels the request — confirm against gpui `Task` semantics.
    _task: Task<Result<()>>,
}
 973
/// Editor-facing inline completion provider that fetches edit predictions from `Zeta`.
pub struct ZetaInlineCompletionProvider {
    /// The model that performs completion requests.
    zeta: Model<Zeta>,
    /// At most two in-flight requests, oldest first.
    pending_completions: ArrayVec<PendingCompletion, 2>,
    /// Id handed to the next request started by `refresh`.
    next_pending_completion_id: usize,
    /// The completion currently offered to the editor, if any.
    current_completion: Option<CurrentInlineCompletion>,
}
 980
 981impl ZetaInlineCompletionProvider {
 982    pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
 983
 984    pub fn new(zeta: Model<Zeta>) -> Self {
 985        Self {
 986            zeta,
 987            pending_completions: ArrayVec::new(),
 988            next_pending_completion_id: 0,
 989            current_completion: None,
 990        }
 991    }
 992}
 993
 994impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
 995    fn name() -> &'static str {
 996        "zeta"
 997    }
 998
 999    fn display_name() -> &'static str {
1000        "Zeta"
1001    }
1002
1003    fn show_completions_in_menu() -> bool {
1004        true
1005    }
1006
1007    fn show_completions_in_normal_mode() -> bool {
1008        true
1009    }
1010
1011    fn is_enabled(
1012        &self,
1013        buffer: &Model<Buffer>,
1014        cursor_position: language::Anchor,
1015        cx: &AppContext,
1016    ) -> bool {
1017        let buffer = buffer.read(cx);
1018        let file = buffer.file();
1019        let language = buffer.language_at(cursor_position);
1020        let settings = all_language_settings(file, cx);
1021        settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
1022    }
1023
1024    fn refresh(
1025        &mut self,
1026        buffer: Model<Buffer>,
1027        position: language::Anchor,
1028        debounce: bool,
1029        cx: &mut ModelContext<Self>,
1030    ) {
1031        let pending_completion_id = self.next_pending_completion_id;
1032        self.next_pending_completion_id += 1;
1033
1034        let task = cx.spawn(|this, mut cx| async move {
1035            if debounce {
1036                cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1037            }
1038
1039            let completion_request = this.update(&mut cx, |this, cx| {
1040                this.zeta.update(cx, |zeta, cx| {
1041                    zeta.request_completion(&buffer, position, cx)
1042                })
1043            });
1044
1045            let mut completion = None;
1046            if let Ok(completion_request) = completion_request {
1047                completion = Some(CurrentInlineCompletion {
1048                    buffer_id: buffer.entity_id(),
1049                    completion: completion_request.await?,
1050                });
1051            }
1052
1053            this.update(&mut cx, |this, cx| {
1054                if this.pending_completions[0].id == pending_completion_id {
1055                    this.pending_completions.remove(0);
1056                } else {
1057                    this.pending_completions.clear();
1058                }
1059
1060                if let Some(new_completion) = completion {
1061                    if let Some(old_completion) = this.current_completion.as_ref() {
1062                        let snapshot = buffer.read(cx).snapshot();
1063                        if new_completion.should_replace_completion(&old_completion, &snapshot) {
1064                            this.zeta.update(cx, |zeta, _cx| {
1065                                zeta.completion_shown(new_completion.completion.id)
1066                            });
1067                            this.current_completion = Some(new_completion);
1068                        }
1069                    } else {
1070                        this.zeta.update(cx, |zeta, _cx| {
1071                            zeta.completion_shown(new_completion.completion.id)
1072                        });
1073                        this.current_completion = Some(new_completion);
1074                    }
1075                } else {
1076                    this.current_completion = None;
1077                }
1078
1079                cx.notify();
1080            })
1081        });
1082
1083        // We always maintain at most two pending completions. When we already
1084        // have two, we replace the newest one.
1085        if self.pending_completions.len() <= 1 {
1086            self.pending_completions.push(PendingCompletion {
1087                id: pending_completion_id,
1088                _task: task,
1089            });
1090        } else if self.pending_completions.len() == 2 {
1091            self.pending_completions.pop();
1092            self.pending_completions.push(PendingCompletion {
1093                id: pending_completion_id,
1094                _task: task,
1095            });
1096        }
1097    }
1098
1099    fn cycle(
1100        &mut self,
1101        _buffer: Model<Buffer>,
1102        _cursor_position: language::Anchor,
1103        _direction: inline_completion::Direction,
1104        _cx: &mut ModelContext<Self>,
1105    ) {
1106        // Right now we don't support cycling.
1107    }
1108
1109    fn accept(&mut self, _cx: &mut ModelContext<Self>) {
1110        self.pending_completions.clear();
1111    }
1112
1113    fn discard(&mut self, _cx: &mut ModelContext<Self>) {
1114        self.pending_completions.clear();
1115        self.current_completion.take();
1116    }
1117
1118    fn suggest(
1119        &mut self,
1120        buffer: &Model<Buffer>,
1121        cursor_position: language::Anchor,
1122        cx: &mut ModelContext<Self>,
1123    ) -> Option<inline_completion::InlineCompletion> {
1124        let CurrentInlineCompletion {
1125            buffer_id,
1126            completion,
1127            ..
1128        } = self.current_completion.as_mut()?;
1129
1130        // Invalidate previous completion if it was generated for a different buffer.
1131        if *buffer_id != buffer.entity_id() {
1132            self.current_completion.take();
1133            return None;
1134        }
1135
1136        let buffer = buffer.read(cx);
1137        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1138            self.current_completion.take();
1139            return None;
1140        };
1141
1142        let cursor_row = cursor_position.to_point(buffer).row;
1143        let (closest_edit_ix, (closest_edit_range, _)) =
1144            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1145                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1146                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1147                cmp::min(distance_from_start, distance_from_end)
1148            })?;
1149
1150        let mut edit_start_ix = closest_edit_ix;
1151        for (range, _) in edits[..edit_start_ix].iter().rev() {
1152            let distance_from_closest_edit =
1153                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1154            if distance_from_closest_edit <= 1 {
1155                edit_start_ix -= 1;
1156            } else {
1157                break;
1158            }
1159        }
1160
1161        let mut edit_end_ix = closest_edit_ix + 1;
1162        for (range, _) in &edits[edit_end_ix..] {
1163            let distance_from_closest_edit =
1164                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1165            if distance_from_closest_edit <= 1 {
1166                edit_end_ix += 1;
1167            } else {
1168                break;
1169            }
1170        }
1171
1172        Some(inline_completion::InlineCompletion {
1173            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1174        })
1175    }
1176}
1177
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Interpolates `completion` onto the buffer's current snapshot and converts the
    /// resulting edits to offset ranges; `None` when the completion no longer applies.
    fn interpolated_edits(
        completion: &InlineCompletion,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let edits = completion.interpolate(&buffer.read(cx).snapshot())?;
        Some(from_completion_edits(&edits, buffer, cx))
    }

    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let initial_edits = to_completion_edits(
            [(2..5, "REM".to_string()), (9..11, "".to_string())],
            &buffer,
            cx,
        );
        let completion = InlineCompletion {
            edits: initial_edits.into(),
            path: Path::new("").into(),
            snapshot: buffer.read(cx).snapshot(),
            id: InlineCompletionId::new(),
            excerpt_range: 0..0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(9..11, "".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            interpolated_edits(&completion, &buffer, cx),
            Some(vec![(4..4, "M".to_string())])
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(interpolated_edits(&completion, &buffer, cx), None);
    }

    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        let http_client = FakeHttpClient::create(move |_| async move {
            let body = serde_json::to_string(&PredictEditsResponse {
                output_excerpt: completion_response.to_string(),
            })
            .unwrap();
            Ok(http_client::Response::builder()
                .status(200)
                .body(body.into())
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new_model(|cx| Zeta::new(client, cx));
        let buffer = cx.new_model(|cx| Buffer::local(buffer_content, cx));
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task =
            zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));

        // The request first needs an LLM token; answer that before awaiting the completion.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        let final_text = buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(final_text, "lorem\nipsum");
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut edits = Vec::new();
        for (Range { start, end }, text) in iterator {
            edits.push((buffer.anchor_after(start)..buffer.anchor_before(end), text));
        }
        edits
    }

    /// Resolves anchor-based edits back to offset ranges in `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut edits = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            edits.push((start..end, text.clone()));
        }
        edits
    }

    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}