zeta.rs

   1mod rate_completion_modal;
   2
   3pub use rate_completion_modal::*;
   4
   5use anyhow::{anyhow, Context as _, Result};
   6use arrayvec::ArrayVec;
   7use client::Client;
   8use collections::{HashMap, HashSet, VecDeque};
   9use futures::AsyncReadExt;
  10use gpui::{
  11    actions, AppContext, AsyncAppContext, Context, EntityId, Global, Model, ModelContext,
  12    Subscription, Task,
  13};
  14use http_client::{HttpClient, Method};
  15use language::{
  16    language_settings::all_language_settings, Anchor, Buffer, BufferSnapshot, OffsetRangeExt,
  17    Point, ToOffset, ToPoint,
  18};
  19use language_models::LlmApiToken;
  20use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
  21use std::{
  22    borrow::Cow,
  23    cmp,
  24    fmt::Write,
  25    future::Future,
  26    mem,
  27    ops::Range,
  28    path::Path,
  29    sync::Arc,
  30    time::{Duration, Instant},
  31};
  32use telemetry_events::InlineCompletionRating;
  33use uuid::Uuid;
  34
  35const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
  36const START_OF_FILE_MARKER: &'static str = "<|start_of_file|>";
  37const EDITABLE_REGION_START_MARKER: &'static str = "<|editable_region_start|>";
  38const EDITABLE_REGION_END_MARKER: &'static str = "<|editable_region_end|>";
  39const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  40
// Declares the `ClearHistory` action under the `zeta` namespace via gpui's `actions!` macro.
// NOTE(review): the code that dispatches this action to `Zeta::clear_history` is not visible
// in this section of the file — confirm where it is wired up.
actions!(zeta, [ClearHistory]);
  42
  43#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
  44pub struct InlineCompletionId(Uuid);
  45
  46impl From<InlineCompletionId> for gpui::ElementId {
  47    fn from(value: InlineCompletionId) -> Self {
  48        gpui::ElementId::Uuid(value.0)
  49    }
  50}
  51
  52impl std::fmt::Display for InlineCompletionId {
  53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  54        write!(f, "{}", self.0)
  55    }
  56}
  57
  58impl InlineCompletionId {
  59    fn new() -> Self {
  60        Self(Uuid::new_v4())
  61    }
  62}
  63
  64#[derive(Clone)]
  65struct ZetaGlobal(Model<Zeta>);
  66
  67impl Global for ZetaGlobal {}
  68
  69#[derive(Clone)]
  70pub struct InlineCompletion {
  71    id: InlineCompletionId,
  72    path: Arc<Path>,
  73    excerpt_range: Range<usize>,
  74    edits: Arc<[(Range<Anchor>, String)]>,
  75    snapshot: BufferSnapshot,
  76    input_outline: Arc<str>,
  77    input_events: Arc<str>,
  78    input_excerpt: Arc<str>,
  79    output_excerpt: Arc<str>,
  80    request_sent_at: Instant,
  81    response_received_at: Instant,
  82}
  83
impl InlineCompletion {
    /// Time elapsed between starting the request and building this completion.
    fn latency(&self) -> Duration {
        self.response_received_at
            .duration_since(self.request_sent_at)
    }

    /// Adjusts this completion for edits the user made after the request was sent, producing
    /// edits meant to be applied to `new_snapshot`.
    ///
    /// Each model edit is compared against the user edits made since `self.snapshot`:
    /// - a model edit no user edit touches is kept with its original anchored range;
    /// - if a user edit covers exactly the model edit's range and the user's inserted text is a
    ///   prefix of the model's text, only the remaining suffix is kept, as an empty-range
    ///   insertion at the end of the user's text;
    /// - any other overlap, or typed text that diverges from the model's text, makes the whole
    ///   completion stale.
    ///
    /// Returns `None` when the completion is stale or when no edits remain.
    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
        let mut edits = Vec::new();

        // User edits since the completion's snapshot: `old` ranges are offsets in
        // `self.snapshot`, `new` ranges are offsets in `new_snapshot`.
        let mut user_edits = new_snapshot
            .edits_since::<usize>(&self.snapshot.version)
            .peekable();
        for (model_old_range, model_new_text) in self.edits.iter() {
            let model_offset_range = model_old_range.to_offset(&self.snapshot);
            // Skip user edits that end strictly before this model edit starts.
            // NOTE(review): assumes both edit sequences are ordered by position — confirm for
            // `edits_since`.
            while let Some(next_user_edit) = user_edits.peek() {
                if next_user_edit.old.end < model_offset_range.start {
                    user_edits.next();
                } else {
                    break;
                }
            }

            if let Some(user_edit) = user_edits.peek() {
                if user_edit.old.start > model_offset_range.end {
                    // The next user edit lies entirely after this model edit: keep it as is.
                    edits.push((model_old_range.clone(), model_new_text.clone()));
                } else if user_edit.old == model_offset_range {
                    // The user edited exactly the range this model edit targets.
                    let user_new_text = new_snapshot
                        .text_for_range(user_edit.new.clone())
                        .collect::<String>();

                    if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                        // The typed text agrees with the prediction so far; offer only the rest,
                        // inserted right after what the user typed.
                        if !model_suffix.is_empty() {
                            edits.push((
                                new_snapshot.anchor_after(user_edit.new.end)
                                    ..new_snapshot.anchor_before(user_edit.new.end),
                                model_suffix.into(),
                            ));
                        }

                        // This user edit is accounted for; don't compare it with later model edits.
                        user_edits.next();
                    } else {
                        // The user typed something the model did not predict.
                        return None;
                    }
                } else {
                    // A user edit overlaps or touches this model edit without matching its
                    // range exactly, so the edit cannot be rebased.
                    return None;
                }
            } else {
                // No user edits remain at or after this model edit: keep it as is.
                edits.push((model_old_range.clone(), model_new_text.clone()));
            }
        }

        // Nothing left to suggest (e.g. the user already typed all of it).
        if edits.is_empty() {
            None
        } else {
            Some(edits)
        }
    }
}
 142
 143impl std::fmt::Debug for InlineCompletion {
 144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 145        f.debug_struct("InlineCompletion")
 146            .field("id", &self.id)
 147            .field("path", &self.path)
 148            .field("edits", &self.edits)
 149            .finish_non_exhaustive()
 150    }
 151}
 152
 153pub struct Zeta {
 154    client: Arc<Client>,
 155    events: VecDeque<Event>,
 156    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 157    recent_completions: VecDeque<InlineCompletion>,
 158    rated_completions: HashSet<InlineCompletionId>,
 159    shown_completions: HashSet<InlineCompletionId>,
 160    llm_token: LlmApiToken,
 161    _llm_token_subscription: Subscription,
 162}
 163
 164impl Zeta {
 165    pub fn global(cx: &mut AppContext) -> Option<Model<Self>> {
 166        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 167    }
 168
 169    pub fn register(client: Arc<Client>, cx: &mut AppContext) -> Model<Self> {
 170        Self::global(cx).unwrap_or_else(|| {
 171            let model = cx.new_model(|cx| Self::new(client, cx));
 172            cx.set_global(ZetaGlobal(model.clone()));
 173            model
 174        })
 175    }
 176
 177    pub fn clear_history(&mut self) {
 178        self.events.clear();
 179    }
 180
 181    fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
 182        let refresh_llm_token_listener = language_models::RefreshLlmTokenListener::global(cx);
 183
 184        Self {
 185            client,
 186            events: VecDeque::new(),
 187            recent_completions: VecDeque::new(),
 188            rated_completions: HashSet::default(),
 189            shown_completions: HashSet::default(),
 190            registered_buffers: HashMap::default(),
 191            llm_token: LlmApiToken::default(),
 192            _llm_token_subscription: cx.subscribe(
 193                &refresh_llm_token_listener,
 194                |this, _listener, _event, cx| {
 195                    let client = this.client.clone();
 196                    let llm_token = this.llm_token.clone();
 197                    cx.spawn(|_this, _cx| async move {
 198                        llm_token.refresh(&client).await?;
 199                        anyhow::Ok(())
 200                    })
 201                    .detach_and_log_err(cx);
 202                },
 203            ),
 204        }
 205    }
 206
 207    fn push_event(&mut self, event: Event) {
 208        const MAX_EVENT_COUNT: usize = 20;
 209
 210        if let Some(Event::BufferChange {
 211            new_snapshot: last_new_snapshot,
 212            timestamp: last_timestamp,
 213            ..
 214        }) = self.events.back_mut()
 215        {
 216            // Coalesce edits for the same buffer when they happen one after the other.
 217            let Event::BufferChange {
 218                old_snapshot,
 219                new_snapshot,
 220                timestamp,
 221            } = &event;
 222
 223            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 224                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 225                && old_snapshot.version == last_new_snapshot.version
 226            {
 227                *last_new_snapshot = new_snapshot.clone();
 228                *last_timestamp = *timestamp;
 229                return;
 230            }
 231        }
 232
 233        self.events.push_back(event);
 234        if self.events.len() > MAX_EVENT_COUNT {
 235            self.events.pop_front();
 236        }
 237    }
 238
 239    pub fn register_buffer(&mut self, buffer: &Model<Buffer>, cx: &mut ModelContext<Self>) {
 240        let buffer_id = buffer.entity_id();
 241        let weak_buffer = buffer.downgrade();
 242
 243        if let std::collections::hash_map::Entry::Vacant(entry) =
 244            self.registered_buffers.entry(buffer_id)
 245        {
 246            let snapshot = buffer.read(cx).snapshot();
 247
 248            entry.insert(RegisteredBuffer {
 249                snapshot,
 250                _subscriptions: [
 251                    cx.subscribe(buffer, move |this, buffer, event, cx| {
 252                        this.handle_buffer_event(buffer, event, cx);
 253                    }),
 254                    cx.observe_release(buffer, move |this, _buffer, _cx| {
 255                        this.registered_buffers.remove(&weak_buffer.entity_id());
 256                    }),
 257                ],
 258            });
 259        };
 260    }
 261
 262    fn handle_buffer_event(
 263        &mut self,
 264        buffer: Model<Buffer>,
 265        event: &language::BufferEvent,
 266        cx: &mut ModelContext<Self>,
 267    ) {
 268        match event {
 269            language::BufferEvent::Edited => {
 270                self.report_changes_for_buffer(&buffer, cx);
 271            }
 272            _ => {}
 273        }
 274    }
 275
 276    pub fn request_completion_impl<F, R>(
 277        &mut self,
 278        buffer: &Model<Buffer>,
 279        position: language::Anchor,
 280        cx: &mut ModelContext<Self>,
 281        perform_predict_edits: F,
 282    ) -> Task<Result<InlineCompletion>>
 283    where
 284        F: FnOnce(Arc<Client>, LlmApiToken, PredictEditsParams) -> R + 'static,
 285        R: Future<Output = Result<PredictEditsResponse>> + Send + 'static,
 286    {
 287        let snapshot = self.report_changes_for_buffer(buffer, cx);
 288        let point = position.to_point(&snapshot);
 289        let offset = point.to_offset(&snapshot);
 290        let excerpt_range = excerpt_range_for_position(point, &snapshot);
 291        let events = self.events.clone();
 292        let path = snapshot
 293            .file()
 294            .map(|f| f.path().clone())
 295            .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 296
 297        let client = self.client.clone();
 298        let llm_token = self.llm_token.clone();
 299
 300        cx.spawn(|this, mut cx| async move {
 301            let request_sent_at = Instant::now();
 302
 303            let input_events = cx
 304                .background_executor()
 305                .spawn(async move {
 306                    let mut input_events = String::new();
 307                    for event in events {
 308                        if !input_events.is_empty() {
 309                            input_events.push('\n');
 310                            input_events.push('\n');
 311                        }
 312                        input_events.push_str(&event.to_prompt());
 313                    }
 314                    input_events
 315                })
 316                .await;
 317
 318            let input_excerpt = prompt_for_excerpt(&snapshot, &excerpt_range, offset);
 319            let input_outline = prompt_for_outline(&snapshot);
 320
 321            log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 322
 323            let body = PredictEditsParams {
 324                input_events: input_events.clone(),
 325                input_excerpt: input_excerpt.clone(),
 326            };
 327
 328            let response = perform_predict_edits(client, llm_token, body).await?;
 329
 330            let output_excerpt = response.output_excerpt;
 331            log::debug!("completion response: {}", output_excerpt);
 332
 333            let inline_completion = Self::process_completion_response(
 334                output_excerpt,
 335                &snapshot,
 336                excerpt_range,
 337                path,
 338                input_outline,
 339                input_events,
 340                input_excerpt,
 341                request_sent_at,
 342                &cx,
 343            )
 344            .await?;
 345
 346            this.update(&mut cx, |this, cx| {
 347                this.recent_completions
 348                    .push_front(inline_completion.clone());
 349                if this.recent_completions.len() > 50 {
 350                    let completion = this.recent_completions.pop_back().unwrap();
 351                    this.shown_completions.remove(&completion.id);
 352                    this.rated_completions.remove(&completion.id);
 353                }
 354                cx.notify();
 355            })?;
 356
 357            Ok(inline_completion)
 358        })
 359    }
 360
 361    // Generates several example completions of various states to fill the Zeta completion modal
 362    #[cfg(any(test, feature = "test-support"))]
 363    pub fn fill_with_fake_completions(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
 364        let test_buffer_text = indoc::indoc! {r#"a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 365            And maybe a short line
 366
 367            Then a few lines
 368
 369            and then another
 370            "#};
 371
 372        let buffer = cx.new_model(|cx| Buffer::local(test_buffer_text, cx));
 373        let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 374
 375        let completion_tasks = vec![
 376            self.fake_completion(
 377                &buffer,
 378                position,
 379                PredictEditsResponse {
 380                    output_excerpt: format!("{EDITABLE_REGION_START_MARKER}
 381a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 382[here's an edit]
 383And maybe a short line
 384Then a few lines
 385and then another
 386{EDITABLE_REGION_END_MARKER}
 387                        ", ),
 388                },
 389                cx,
 390            ),
 391            self.fake_completion(
 392                &buffer,
 393                position,
 394                PredictEditsResponse {
 395                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 396a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 397And maybe a short line
 398[and another edit]
 399Then a few lines
 400and then another
 401{EDITABLE_REGION_END_MARKER}
 402                        "#),
 403                },
 404                cx,
 405            ),
 406            self.fake_completion(
 407                &buffer,
 408                position,
 409                PredictEditsResponse {
 410                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 411a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 412And maybe a short line
 413
 414Then a few lines
 415
 416and then another
 417{EDITABLE_REGION_END_MARKER}
 418                        "#),
 419                },
 420                cx,
 421            ),
 422            self.fake_completion(
 423                &buffer,
 424                position,
 425                PredictEditsResponse {
 426                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 427a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 428And maybe a short line
 429
 430Then a few lines
 431
 432and then another
 433{EDITABLE_REGION_END_MARKER}
 434                        "#),
 435                },
 436                cx,
 437            ),
 438            self.fake_completion(
 439                &buffer,
 440                position,
 441                PredictEditsResponse {
 442                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 443a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 444And maybe a short line
 445Then a few lines
 446[a third completion]
 447and then another
 448{EDITABLE_REGION_END_MARKER}
 449                        "#),
 450                },
 451                cx,
 452            ),
 453            self.fake_completion(
 454                &buffer,
 455                position,
 456                PredictEditsResponse {
 457                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 458a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 459And maybe a short line
 460and then another
 461[fourth completion example]
 462{EDITABLE_REGION_END_MARKER}
 463                        "#),
 464                },
 465                cx,
 466            ),
 467            self.fake_completion(
 468                &buffer,
 469                position,
 470                PredictEditsResponse {
 471                    output_excerpt: format!(r#"{EDITABLE_REGION_START_MARKER}
 472a longggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg line
 473And maybe a short line
 474Then a few lines
 475and then another
 476[fifth and final completion]
 477{EDITABLE_REGION_END_MARKER}
 478                        "#),
 479                },
 480                cx,
 481            ),
 482        ];
 483
 484        cx.spawn(|zeta, mut cx| async move {
 485            for task in completion_tasks {
 486                task.await.unwrap();
 487            }
 488
 489            zeta.update(&mut cx, |zeta, _cx| {
 490                zeta.recent_completions.get_mut(2).unwrap().edits = Arc::new([]);
 491                zeta.recent_completions.get_mut(3).unwrap().edits = Arc::new([]);
 492            })
 493            .ok();
 494        })
 495    }
 496
 497    #[cfg(any(test, feature = "test-support"))]
 498    pub fn fake_completion(
 499        &mut self,
 500        buffer: &Model<Buffer>,
 501        position: language::Anchor,
 502        response: PredictEditsResponse,
 503        cx: &mut ModelContext<Self>,
 504    ) -> Task<Result<InlineCompletion>> {
 505        use std::future::ready;
 506
 507        self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response)))
 508    }
 509
 510    pub fn request_completion(
 511        &mut self,
 512        buffer: &Model<Buffer>,
 513        position: language::Anchor,
 514        cx: &mut ModelContext<Self>,
 515    ) -> Task<Result<InlineCompletion>> {
 516        self.request_completion_impl(buffer, position, cx, Self::perform_predict_edits)
 517    }
 518
 519    fn perform_predict_edits(
 520        client: Arc<Client>,
 521        llm_token: LlmApiToken,
 522        body: PredictEditsParams,
 523    ) -> impl Future<Output = Result<PredictEditsResponse>> {
 524        async move {
 525            let http_client = client.http_client();
 526            let mut token = llm_token.acquire(&client).await?;
 527            let mut did_retry = false;
 528
 529            loop {
 530                let request_builder = http_client::Request::builder();
 531                let request = request_builder
 532                    .method(Method::POST)
 533                    .uri(
 534                        http_client
 535                            .build_zed_llm_url("/predict_edits", &[])?
 536                            .as_ref(),
 537                    )
 538                    .header("Content-Type", "application/json")
 539                    .header("Authorization", format!("Bearer {}", token))
 540                    .body(serde_json::to_string(&body)?.into())?;
 541
 542                let mut response = http_client.send(request).await?;
 543
 544                if response.status().is_success() {
 545                    let mut body = String::new();
 546                    response.body_mut().read_to_string(&mut body).await?;
 547                    return Ok(serde_json::from_str(&body)?);
 548                } else if !did_retry
 549                    && response
 550                        .headers()
 551                        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 552                        .is_some()
 553                {
 554                    did_retry = true;
 555                    token = llm_token.refresh(&client).await?;
 556                } else {
 557                    let mut body = String::new();
 558                    response.body_mut().read_to_string(&mut body).await?;
 559                    return Err(anyhow!(
 560                        "error predicting edits.\nStatus: {:?}\nBody: {}",
 561                        response.status(),
 562                        body
 563                    ));
 564                }
 565            }
 566        }
 567    }
 568
 569    #[allow(clippy::too_many_arguments)]
 570    fn process_completion_response(
 571        output_excerpt: String,
 572        snapshot: &BufferSnapshot,
 573        excerpt_range: Range<usize>,
 574        path: Arc<Path>,
 575        input_outline: String,
 576        input_events: String,
 577        input_excerpt: String,
 578        request_sent_at: Instant,
 579        cx: &AsyncAppContext,
 580    ) -> Task<Result<InlineCompletion>> {
 581        let snapshot = snapshot.clone();
 582        cx.background_executor().spawn(async move {
 583            let content = output_excerpt.replace(CURSOR_MARKER, "");
 584
 585            let codefence_start = content
 586                .find(EDITABLE_REGION_START_MARKER)
 587                .context("could not find start marker")?;
 588            let content = &content[codefence_start..];
 589
 590            let newline_ix = content.find('\n').context("could not find newline")?;
 591            let content = &content[newline_ix + 1..];
 592
 593            let codefence_end = content
 594                .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
 595                .context("could not find end marker")?;
 596            let new_text = &content[..codefence_end];
 597
 598            let old_text = snapshot
 599                .text_for_range(excerpt_range.clone())
 600                .collect::<String>();
 601
 602            let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, &snapshot);
 603
 604            Ok(InlineCompletion {
 605                id: InlineCompletionId::new(),
 606                path,
 607                excerpt_range,
 608                edits: edits.into(),
 609                snapshot: snapshot.clone(),
 610                input_outline: input_outline.into(),
 611                input_events: input_events.into(),
 612                input_excerpt: input_excerpt.into(),
 613                output_excerpt: output_excerpt.into(),
 614                request_sent_at,
 615                response_received_at: Instant::now(),
 616            })
 617        })
 618    }
 619
 620    pub fn compute_edits(
 621        old_text: String,
 622        new_text: &str,
 623        offset: usize,
 624        snapshot: &BufferSnapshot,
 625    ) -> Vec<(Range<Anchor>, String)> {
 626        let diff = similar::TextDiff::from_words(old_text.as_str(), new_text);
 627
 628        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
 629        let mut old_start = offset;
 630        for change in diff.iter_all_changes() {
 631            let value = change.value();
 632            match change.tag() {
 633                similar::ChangeTag::Equal => {
 634                    old_start += value.len();
 635                }
 636                similar::ChangeTag::Delete => {
 637                    let old_end = old_start + value.len();
 638                    if let Some((last_old_range, _)) = edits.last_mut() {
 639                        if last_old_range.end == old_start {
 640                            last_old_range.end = old_end;
 641                        } else {
 642                            edits.push((old_start..old_end, String::new()));
 643                        }
 644                    } else {
 645                        edits.push((old_start..old_end, String::new()));
 646                    }
 647                    old_start = old_end;
 648                }
 649                similar::ChangeTag::Insert => {
 650                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
 651                        if last_old_range.end == old_start {
 652                            last_new_text.push_str(value);
 653                        } else {
 654                            edits.push((old_start..old_start, value.into()));
 655                        }
 656                    } else {
 657                        edits.push((old_start..old_start, value.into()));
 658                    }
 659                }
 660            }
 661        }
 662
 663        edits
 664            .into_iter()
 665            .map(|(mut old_range, new_text)| {
 666                let prefix_len = common_prefix(
 667                    snapshot.chars_for_range(old_range.clone()),
 668                    new_text.chars(),
 669                );
 670                old_range.start += prefix_len;
 671                let suffix_len = common_prefix(
 672                    snapshot.reversed_chars_for_range(old_range.clone()),
 673                    new_text[prefix_len..].chars().rev(),
 674                );
 675                old_range.end = old_range.end.saturating_sub(suffix_len);
 676
 677                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 678                (
 679                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
 680                    new_text,
 681                )
 682            })
 683            .collect()
 684    }
 685
 686    pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {
 687        self.rated_completions.contains(&completion_id)
 688    }
 689
 690    pub fn was_completion_shown(&self, completion_id: InlineCompletionId) -> bool {
 691        self.shown_completions.contains(&completion_id)
 692    }
 693
 694    pub fn completion_shown(&mut self, completion_id: InlineCompletionId) {
 695        self.shown_completions.insert(completion_id);
 696    }
 697
 698    pub fn rate_completion(
 699        &mut self,
 700        completion: &InlineCompletion,
 701        rating: InlineCompletionRating,
 702        feedback: String,
 703        cx: &mut ModelContext<Self>,
 704    ) {
 705        telemetry::event!(
 706            "Inline Completion Rated",
 707            rating,
 708            input_events = completion.input_events,
 709            input_excerpt = completion.input_excerpt,
 710            input_outline = completion.input_outline,
 711            output_excerpt = completion.output_excerpt,
 712            feedback
 713        );
 714        self.client.telemetry().flush_events();
 715        cx.notify();
 716    }
 717
 718    pub fn recent_completions(&self) -> impl DoubleEndedIterator<Item = &InlineCompletion> {
 719        self.recent_completions.iter()
 720    }
 721
 722    pub fn recent_completions_len(&self) -> usize {
 723        self.recent_completions.len()
 724    }
 725
 726    fn report_changes_for_buffer(
 727        &mut self,
 728        buffer: &Model<Buffer>,
 729        cx: &mut ModelContext<Self>,
 730    ) -> BufferSnapshot {
 731        self.register_buffer(buffer, cx);
 732
 733        let registered_buffer = self
 734            .registered_buffers
 735            .get_mut(&buffer.entity_id())
 736            .unwrap();
 737        let new_snapshot = buffer.read(cx).snapshot();
 738
 739        if new_snapshot.version != registered_buffer.snapshot.version {
 740            let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 741            self.push_event(Event::BufferChange {
 742                old_snapshot,
 743                new_snapshot: new_snapshot.clone(),
 744                timestamp: Instant::now(),
 745            });
 746        }
 747
 748        new_snapshot
 749    }
 750}
 751
 752fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 753    a.zip(b)
 754        .take_while(|(a, b)| a == b)
 755        .map(|(a, _)| a.len_utf8())
 756        .sum()
 757}
 758
 759fn prompt_for_outline(snapshot: &BufferSnapshot) -> String {
 760    let mut input_outline = String::new();
 761
 762    writeln!(
 763        input_outline,
 764        "```{}",
 765        snapshot
 766            .file()
 767            .map_or(Cow::Borrowed("untitled"), |file| file
 768                .path()
 769                .to_string_lossy())
 770    )
 771    .unwrap();
 772
 773    if let Some(outline) = snapshot.outline(None) {
 774        let guess_size = outline.items.len() * 15;
 775        input_outline.reserve(guess_size);
 776        for item in outline.items.iter() {
 777            let spacing = " ".repeat(item.depth);
 778            writeln!(input_outline, "{}{}", spacing, item.text).unwrap();
 779        }
 780    }
 781
 782    writeln!(input_outline, "```").unwrap();
 783
 784    input_outline
 785}
 786
 787fn prompt_for_excerpt(
 788    snapshot: &BufferSnapshot,
 789    excerpt_range: &Range<usize>,
 790    offset: usize,
 791) -> String {
 792    let mut prompt_excerpt = String::new();
 793    writeln!(
 794        prompt_excerpt,
 795        "```{}",
 796        snapshot
 797            .file()
 798            .map_or(Cow::Borrowed("untitled"), |file| file
 799                .path()
 800                .to_string_lossy())
 801    )
 802    .unwrap();
 803
 804    if excerpt_range.start == 0 {
 805        writeln!(prompt_excerpt, "{START_OF_FILE_MARKER}").unwrap();
 806    }
 807
 808    let point_range = excerpt_range.to_point(snapshot);
 809    if point_range.start.row > 0 && !snapshot.is_line_blank(point_range.start.row - 1) {
 810        let extra_context_line_range = Point::new(point_range.start.row - 1, 0)..point_range.start;
 811        for chunk in snapshot.text_for_range(extra_context_line_range) {
 812            prompt_excerpt.push_str(chunk);
 813        }
 814    }
 815    writeln!(prompt_excerpt, "{EDITABLE_REGION_START_MARKER}").unwrap();
 816    for chunk in snapshot.text_for_range(excerpt_range.start..offset) {
 817        prompt_excerpt.push_str(chunk);
 818    }
 819    prompt_excerpt.push_str(CURSOR_MARKER);
 820    for chunk in snapshot.text_for_range(offset..excerpt_range.end) {
 821        prompt_excerpt.push_str(chunk);
 822    }
 823    write!(prompt_excerpt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
 824
 825    if point_range.end.row < snapshot.max_point().row
 826        && !snapshot.is_line_blank(point_range.end.row + 1)
 827    {
 828        let extra_context_line_range = point_range.end
 829            ..Point::new(
 830                point_range.end.row + 1,
 831                snapshot.line_len(point_range.end.row + 1),
 832            );
 833        for chunk in snapshot.text_for_range(extra_context_line_range) {
 834            prompt_excerpt.push_str(chunk);
 835        }
 836    }
 837
 838    write!(prompt_excerpt, "\n```").unwrap();
 839    prompt_excerpt
 840}
 841
 842fn excerpt_range_for_position(point: Point, snapshot: &BufferSnapshot) -> Range<usize> {
 843    const CONTEXT_LINES: u32 = 32;
 844
 845    let mut context_lines_before = CONTEXT_LINES;
 846    let mut context_lines_after = CONTEXT_LINES;
 847    if point.row < CONTEXT_LINES {
 848        context_lines_after += CONTEXT_LINES - point.row;
 849    } else if point.row + CONTEXT_LINES > snapshot.max_point().row {
 850        context_lines_before += (point.row + CONTEXT_LINES) - snapshot.max_point().row;
 851    }
 852
 853    let excerpt_start_row = point.row.saturating_sub(context_lines_before);
 854    let excerpt_start = Point::new(excerpt_start_row, 0);
 855    let excerpt_end_row = cmp::min(point.row + context_lines_after, snapshot.max_point().row);
 856    let excerpt_end = Point::new(excerpt_end_row, snapshot.line_len(excerpt_end_row));
 857    excerpt_start.to_offset(snapshot)..excerpt_end.to_offset(snapshot)
 858}
 859
 860struct RegisteredBuffer {
 861    snapshot: BufferSnapshot,
 862    _subscriptions: [gpui::Subscription; 2],
 863}
 864
 865#[derive(Clone)]
 866enum Event {
 867    BufferChange {
 868        old_snapshot: BufferSnapshot,
 869        new_snapshot: BufferSnapshot,
 870        timestamp: Instant,
 871    },
 872}
 873
 874impl Event {
 875    fn to_prompt(&self) -> String {
 876        match self {
 877            Event::BufferChange {
 878                old_snapshot,
 879                new_snapshot,
 880                ..
 881            } => {
 882                let mut prompt = String::new();
 883
 884                let old_path = old_snapshot
 885                    .file()
 886                    .map(|f| f.path().as_ref())
 887                    .unwrap_or(Path::new("untitled"));
 888                let new_path = new_snapshot
 889                    .file()
 890                    .map(|f| f.path().as_ref())
 891                    .unwrap_or(Path::new("untitled"));
 892                if old_path != new_path {
 893                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
 894                }
 895
 896                let diff =
 897                    similar::TextDiff::from_lines(&old_snapshot.text(), &new_snapshot.text())
 898                        .unified_diff()
 899                        .to_string();
 900                if !diff.is_empty() {
 901                    write!(
 902                        prompt,
 903                        "User edited {:?}:\n```diff\n{}\n```",
 904                        new_path, diff
 905                    )
 906                    .unwrap();
 907                }
 908
 909                prompt
 910            }
 911        }
 912    }
 913}
 914
 915#[derive(Debug, Clone)]
 916struct CurrentInlineCompletion {
 917    buffer_id: EntityId,
 918    completion: InlineCompletion,
 919}
 920
 921impl CurrentInlineCompletion {
 922    fn should_replace_completion(&self, old_completion: &Self, snapshot: &BufferSnapshot) -> bool {
 923        if self.buffer_id != old_completion.buffer_id {
 924            return true;
 925        }
 926
 927        let Some(old_edits) = old_completion.completion.interpolate(&snapshot) else {
 928            return true;
 929        };
 930        let Some(new_edits) = self.completion.interpolate(&snapshot) else {
 931            return false;
 932        };
 933
 934        if old_edits.len() == 1 && new_edits.len() == 1 {
 935            let (old_range, old_text) = &old_edits[0];
 936            let (new_range, new_text) = &new_edits[0];
 937            new_range == old_range && new_text.starts_with(old_text)
 938        } else {
 939            true
 940        }
 941    }
 942}
 943
 944struct PendingCompletion {
 945    id: usize,
 946    _task: Task<Result<()>>,
 947}
 948
 949pub struct ZetaInlineCompletionProvider {
 950    zeta: Model<Zeta>,
 951    pending_completions: ArrayVec<PendingCompletion, 2>,
 952    next_pending_completion_id: usize,
 953    current_completion: Option<CurrentInlineCompletion>,
 954}
 955
 956impl ZetaInlineCompletionProvider {
 957    pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(8);
 958
 959    pub fn new(zeta: Model<Zeta>) -> Self {
 960        Self {
 961            zeta,
 962            pending_completions: ArrayVec::new(),
 963            next_pending_completion_id: 0,
 964            current_completion: None,
 965        }
 966    }
 967}
 968
 969impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvider {
 970    fn name() -> &'static str {
 971        "zeta"
 972    }
 973
 974    fn display_name() -> &'static str {
 975        "Zeta"
 976    }
 977
 978    fn show_completions_in_menu() -> bool {
 979        true
 980    }
 981
 982    fn show_completions_in_normal_mode() -> bool {
 983        true
 984    }
 985
 986    fn is_enabled(
 987        &self,
 988        buffer: &Model<Buffer>,
 989        cursor_position: language::Anchor,
 990        cx: &AppContext,
 991    ) -> bool {
 992        let buffer = buffer.read(cx);
 993        let file = buffer.file();
 994        let language = buffer.language_at(cursor_position);
 995        let settings = all_language_settings(file, cx);
 996        settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref()), cx)
 997    }
 998
 999    fn refresh(
1000        &mut self,
1001        buffer: Model<Buffer>,
1002        position: language::Anchor,
1003        debounce: bool,
1004        cx: &mut ModelContext<Self>,
1005    ) {
1006        let pending_completion_id = self.next_pending_completion_id;
1007        self.next_pending_completion_id += 1;
1008
1009        let task = cx.spawn(|this, mut cx| async move {
1010            if debounce {
1011                cx.background_executor().timer(Self::DEBOUNCE_TIMEOUT).await;
1012            }
1013
1014            let completion_request = this.update(&mut cx, |this, cx| {
1015                this.zeta.update(cx, |zeta, cx| {
1016                    zeta.request_completion(&buffer, position, cx)
1017                })
1018            });
1019
1020            let mut completion = None;
1021            if let Ok(completion_request) = completion_request {
1022                completion = Some(CurrentInlineCompletion {
1023                    buffer_id: buffer.entity_id(),
1024                    completion: completion_request.await?,
1025                });
1026            }
1027
1028            this.update(&mut cx, |this, cx| {
1029                if this.pending_completions[0].id == pending_completion_id {
1030                    this.pending_completions.remove(0);
1031                } else {
1032                    this.pending_completions.clear();
1033                }
1034
1035                if let Some(new_completion) = completion {
1036                    if let Some(old_completion) = this.current_completion.as_ref() {
1037                        let snapshot = buffer.read(cx).snapshot();
1038                        if new_completion.should_replace_completion(&old_completion, &snapshot) {
1039                            this.zeta.update(cx, |zeta, _cx| {
1040                                zeta.completion_shown(new_completion.completion.id)
1041                            });
1042                            this.current_completion = Some(new_completion);
1043                        }
1044                    } else {
1045                        this.zeta.update(cx, |zeta, _cx| {
1046                            zeta.completion_shown(new_completion.completion.id)
1047                        });
1048                        this.current_completion = Some(new_completion);
1049                    }
1050                } else {
1051                    this.current_completion = None;
1052                }
1053
1054                cx.notify();
1055            })
1056        });
1057
1058        // We always maintain at most two pending completions. When we already
1059        // have two, we replace the newest one.
1060        if self.pending_completions.len() <= 1 {
1061            self.pending_completions.push(PendingCompletion {
1062                id: pending_completion_id,
1063                _task: task,
1064            });
1065        } else if self.pending_completions.len() == 2 {
1066            self.pending_completions.pop();
1067            self.pending_completions.push(PendingCompletion {
1068                id: pending_completion_id,
1069                _task: task,
1070            });
1071        }
1072    }
1073
1074    fn cycle(
1075        &mut self,
1076        _buffer: Model<Buffer>,
1077        _cursor_position: language::Anchor,
1078        _direction: inline_completion::Direction,
1079        _cx: &mut ModelContext<Self>,
1080    ) {
1081        // Right now we don't support cycling.
1082    }
1083
1084    fn accept(&mut self, _cx: &mut ModelContext<Self>) {
1085        self.pending_completions.clear();
1086    }
1087
1088    fn discard(&mut self, _cx: &mut ModelContext<Self>) {
1089        self.pending_completions.clear();
1090        self.current_completion.take();
1091    }
1092
1093    fn suggest(
1094        &mut self,
1095        buffer: &Model<Buffer>,
1096        cursor_position: language::Anchor,
1097        cx: &mut ModelContext<Self>,
1098    ) -> Option<inline_completion::InlineCompletion> {
1099        let CurrentInlineCompletion {
1100            buffer_id,
1101            completion,
1102            ..
1103        } = self.current_completion.as_mut()?;
1104
1105        // Invalidate previous completion if it was generated for a different buffer.
1106        if *buffer_id != buffer.entity_id() {
1107            self.current_completion.take();
1108            return None;
1109        }
1110
1111        let buffer = buffer.read(cx);
1112        let Some(edits) = completion.interpolate(&buffer.snapshot()) else {
1113            self.current_completion.take();
1114            return None;
1115        };
1116
1117        let cursor_row = cursor_position.to_point(buffer).row;
1118        let (closest_edit_ix, (closest_edit_range, _)) =
1119            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1120                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1121                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1122                cmp::min(distance_from_start, distance_from_end)
1123            })?;
1124
1125        let mut edit_start_ix = closest_edit_ix;
1126        for (range, _) in edits[..edit_start_ix].iter().rev() {
1127            let distance_from_closest_edit =
1128                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1129            if distance_from_closest_edit <= 1 {
1130                edit_start_ix -= 1;
1131            } else {
1132                break;
1133            }
1134        }
1135
1136        let mut edit_end_ix = closest_edit_ix + 1;
1137        for (range, _) in &edits[edit_end_ix..] {
1138            let distance_from_closest_edit =
1139                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1140            if distance_from_closest_edit <= 1 {
1141                edit_end_ix += 1;
1142            } else {
1143                break;
1144            }
1145        }
1146
1147        Some(inline_completion::InlineCompletion {
1148            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1149        })
1150    }
1151}
1152
#[cfg(test)]
mod tests {
    use client::test::FakeServer;
    use clock::FakeSystemClock;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use indoc::indoc;
    use language_models::RefreshLlmTokenListener;
    use rpc::proto;
    use settings::SettingsStore;

    use super::*;

    /// Checks how `InlineCompletion::interpolate` re-maps a prediction as the buffer
    /// changes:
    /// - text the user has already typed is trimmed from the predicted edits;
    /// - later edits shift along with the buffer;
    /// - an edit that diverges from the prediction makes `interpolate` return `None`.
    #[gpui::test]
    fn test_inline_completion_basic_interpolation(cx: &mut AppContext) {
        let buffer = cx.new_model(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // Predicted edits: replace "rem" (2..5) with "REM" and delete "um" (9..11).
        let completion = InlineCompletion {
            edits: to_completion_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into(),
            path: Path::new("").into(),
            snapshot: buffer.read(cx).snapshot(),
            id: InlineCompletionId::new(),
            excerpt_range: 0..0,
            input_outline: "".into(),
            input_events: "".into(),
            input_excerpt: "".into(),
            output_excerpt: "".into(),
            request_sent_at: Instant::now(),
            response_received_at: Instant::now(),
        };

        // Unchanged buffer: the edits come back exactly as predicted.
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Deleting "rem" turns the replacement into an insertion at offset 2 and
        // shifts the deletion left by three characters.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
        );

        // Undoing that deletion restores the original edits.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
        );

        // Replacing "rem" with "R" leaves only the untyped "EM" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
        );

        // Typing "E" leaves only "M" to insert.
        buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Typing "M" completes the first edit; only the deletion remains.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(9..11, "".to_string())]
        );

        // Removing the typed "M" brings the "M" insertion back.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string()), (8..10, "".to_string())]
        );

        // Deleting "um" by hand leaves only the "M" insertion.
        buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
        assert_eq!(
            from_completion_edits(
                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                &buffer,
                cx
            ),
            vec![(4..4, "M".to_string())]
        );

        // Deleting " i" (4..6) diverges from the prediction, so it no longer applies.
        buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
    }

    /// Requests a completion with the cursor at the very end of a buffer whose last
    /// line is empty. A fake HTTP client returns a fixed `output_excerpt`. The test
    /// checks that applying the resulting edits appends the predicted `ipsum` line.
    #[gpui::test]
    async fn test_inline_completion_end_of_buffer(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            client::init_settings(cx);
        });

        let buffer_content = "lorem\n";
        // Model output: the editable region now ends with an extra "ipsum" line.
        let completion_response = indoc! {"
            ```animals.js
            <|start_of_file|>
            <|editable_region_start|>
            lorem
            ipsum
            <|editable_region_end|>
            ```"};

        // Every HTTP request gets a 200 response carrying `completion_response`.
        let http_client = FakeHttpClient::create(move |_| async move {
            Ok(http_client::Response::builder()
                .status(200)
                .body(
                    serde_json::to_string(&PredictEditsResponse {
                        output_excerpt: completion_response.to_string(),
                    })
                    .unwrap()
                    .into(),
                )
                .unwrap())
        });

        let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
        cx.update(|cx| {
            RefreshLlmTokenListener::register(client.clone(), cx);
        });
        let server = FakeServer::for_client(42, &client, cx).await;

        let zeta = cx.new_model(|cx| Zeta::new(client, cx));
        let buffer = cx.new_model(|cx| Buffer::local(buffer_content, cx));
        // Row 1, column 0 is the end of "lorem\n".
        let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
        let completion_task =
            zeta.update(cx, |zeta, cx| zeta.request_completion(&buffer, cursor, cx));

        // The request first asks the server for an LLM token; answer with an empty one.
        let token_request = server.receive::<proto::GetLlmToken>().await.unwrap();
        server.respond(
            token_request.receipt(),
            proto::GetLlmTokenResponse { token: "".into() },
        );

        let completion = completion_task.await.unwrap();
        buffer.update(cx, |buffer, cx| {
            buffer.edit(completion.edits.iter().cloned(), None, cx)
        });
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "lorem\nipsum"
        );
    }

    /// Converts offset-range edits into anchor-range edits on `buffer`.
    /// Starts use `anchor_after` and ends use `anchor_before`.
    fn to_completion_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-range edits back to offset ranges in the current contents of `buffer`.
    fn from_completion_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Model<Buffer>,
        cx: &AppContext,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }

    /// Enables `env_logger` output when `RUST_LOG` is set; registered with `ctor` so it
    /// runs when the test binary is loaded.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
}