assistant.rs

   1use crate::{
   2    assistant_settings::{AssistantDockPosition, AssistantSettings},
   3    OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role,
   4};
   5use anyhow::{anyhow, Result};
   6use chrono::{DateTime, Local};
   7use collections::{HashMap, HashSet};
   8use editor::{
   9    display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
  10    scroll::{
  11        autoscroll::{Autoscroll, AutoscrollStrategy},
  12        ScrollAnchor,
  13    },
  14    Anchor, DisplayPoint, Editor, ToOffset as _,
  15};
  16use fs::Fs;
  17use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  18use gpui::{
  19    actions,
  20    elements::*,
  21    executor::Background,
  22    geometry::vector::vec2f,
  23    platform::{CursorStyle, MouseButton},
  24    Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
  25    Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
  26};
  27use isahc::{http::StatusCode, Request, RequestExt};
  28use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
  29use serde::Deserialize;
  30use settings::SettingsStore;
  31use std::{
  32    borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
  33    time::Duration,
  34};
  35use util::{post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
  36use workspace::{
  37    dock::{DockPosition, Panel},
  38    item::Item,
  39    pane, Pane, Workspace,
  40};
  41
  42const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
  43
  44actions!(
  45    assistant,
  46    [NewContext, Assist, QuoteSelection, ToggleFocus, ResetKey]
  47);
  48
  49pub fn init(cx: &mut AppContext) {
  50    settings::register::<AssistantSettings>(cx);
  51    cx.add_action(
  52        |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
  53            if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
  54                this.update(cx, |this, cx| this.add_context(cx))
  55            }
  56
  57            workspace.focus_panel::<AssistantPanel>(cx);
  58        },
  59    );
  60    cx.add_action(AssistantEditor::assist);
  61    cx.capture_action(AssistantEditor::cancel_last_assist);
  62    cx.add_action(AssistantEditor::quote_selection);
  63    cx.capture_action(AssistantEditor::copy);
  64    cx.add_action(AssistantPanel::save_api_key);
  65    cx.add_action(AssistantPanel::reset_api_key);
  66}
  67
  68pub enum AssistantPanelEvent {
  69    ZoomIn,
  70    ZoomOut,
  71    Focus,
  72    Close,
  73    DockPositionChanged,
  74}
  75
  76pub struct AssistantPanel {
  77    width: Option<f32>,
  78    height: Option<f32>,
  79    pane: ViewHandle<Pane>,
  80    api_key: Rc<RefCell<Option<String>>>,
  81    api_key_editor: Option<ViewHandle<Editor>>,
  82    has_read_credentials: bool,
  83    languages: Arc<LanguageRegistry>,
  84    fs: Arc<dyn Fs>,
  85    subscriptions: Vec<Subscription>,
  86}
  87
  88impl AssistantPanel {
  89    pub fn load(
  90        workspace: WeakViewHandle<Workspace>,
  91        cx: AsyncAppContext,
  92    ) -> Task<Result<ViewHandle<Self>>> {
  93        cx.spawn(|mut cx| async move {
  94            // TODO: deserialize state.
  95            workspace.update(&mut cx, |workspace, cx| {
  96                cx.add_view::<Self, _>(|cx| {
  97                    let weak_self = cx.weak_handle();
  98                    let pane = cx.add_view(|cx| {
  99                        let mut pane = Pane::new(
 100                            workspace.weak_handle(),
 101                            workspace.project().clone(),
 102                            workspace.app_state().background_actions,
 103                            Default::default(),
 104                            cx,
 105                        );
 106                        pane.set_can_split(false, cx);
 107                        pane.set_can_navigate(false, cx);
 108                        pane.on_can_drop(move |_, _| false);
 109                        pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
 110                            let weak_self = weak_self.clone();
 111                            Flex::row()
 112                                .with_child(Pane::render_tab_bar_button(
 113                                    0,
 114                                    "icons/plus_12.svg",
 115                                    false,
 116                                    Some(("New Context".into(), Some(Box::new(NewContext)))),
 117                                    cx,
 118                                    move |_, cx| {
 119                                        let weak_self = weak_self.clone();
 120                                        cx.window_context().defer(move |cx| {
 121                                            if let Some(this) = weak_self.upgrade(cx) {
 122                                                this.update(cx, |this, cx| this.add_context(cx));
 123                                            }
 124                                        })
 125                                    },
 126                                    None,
 127                                ))
 128                                .with_child(Pane::render_tab_bar_button(
 129                                    1,
 130                                    if pane.is_zoomed() {
 131                                        "icons/minimize_8.svg"
 132                                    } else {
 133                                        "icons/maximize_8.svg"
 134                                    },
 135                                    pane.is_zoomed(),
 136                                    Some((
 137                                        "Toggle Zoom".into(),
 138                                        Some(Box::new(workspace::ToggleZoom)),
 139                                    )),
 140                                    cx,
 141                                    move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
 142                                    None,
 143                                ))
 144                                .into_any()
 145                        });
 146                        let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
 147                        pane.toolbar()
 148                            .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
 149                        pane
 150                    });
 151
 152                    let mut this = Self {
 153                        pane,
 154                        api_key: Rc::new(RefCell::new(None)),
 155                        api_key_editor: None,
 156                        has_read_credentials: false,
 157                        languages: workspace.app_state().languages.clone(),
 158                        fs: workspace.app_state().fs.clone(),
 159                        width: None,
 160                        height: None,
 161                        subscriptions: Default::default(),
 162                    };
 163
 164                    let mut old_dock_position = this.position(cx);
 165                    this.subscriptions = vec![
 166                        cx.observe(&this.pane, |_, _, cx| cx.notify()),
 167                        cx.subscribe(&this.pane, Self::handle_pane_event),
 168                        cx.observe_global::<SettingsStore, _>(move |this, cx| {
 169                            let new_dock_position = this.position(cx);
 170                            if new_dock_position != old_dock_position {
 171                                old_dock_position = new_dock_position;
 172                                cx.emit(AssistantPanelEvent::DockPositionChanged);
 173                            }
 174                        }),
 175                    ];
 176
 177                    this
 178                })
 179            })
 180        })
 181    }
 182
 183    fn handle_pane_event(
 184        &mut self,
 185        _pane: ViewHandle<Pane>,
 186        event: &pane::Event,
 187        cx: &mut ViewContext<Self>,
 188    ) {
 189        match event {
 190            pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
 191            pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
 192            pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
 193            pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
 194            _ => {}
 195        }
 196    }
 197
 198    fn add_context(&mut self, cx: &mut ViewContext<Self>) {
 199        let focus = self.has_focus(cx);
 200        let editor = cx
 201            .add_view(|cx| AssistantEditor::new(self.api_key.clone(), self.languages.clone(), cx));
 202        self.subscriptions
 203            .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
 204        self.pane.update(cx, |pane, cx| {
 205            pane.add_item(Box::new(editor), true, focus, None, cx)
 206        });
 207    }
 208
 209    fn handle_assistant_editor_event(
 210        &mut self,
 211        _: ViewHandle<AssistantEditor>,
 212        event: &AssistantEditorEvent,
 213        cx: &mut ViewContext<Self>,
 214    ) {
 215        match event {
 216            AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
 217        }
 218    }
 219
 220    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
 221        if let Some(api_key) = self
 222            .api_key_editor
 223            .as_ref()
 224            .map(|editor| editor.read(cx).text(cx))
 225        {
 226            if !api_key.is_empty() {
 227                cx.platform()
 228                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
 229                    .log_err();
 230                *self.api_key.borrow_mut() = Some(api_key);
 231                self.api_key_editor.take();
 232                cx.focus_self();
 233                cx.notify();
 234            }
 235        } else {
 236            cx.propagate_action();
 237        }
 238    }
 239
 240    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
 241        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
 242        self.api_key.take();
 243        self.api_key_editor = Some(build_api_key_editor(cx));
 244        cx.focus_self();
 245        cx.notify();
 246    }
 247}
 248
 249fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
 250    cx.add_view(|cx| {
 251        let mut editor = Editor::single_line(
 252            Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
 253            cx,
 254        );
 255        editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
 256        editor
 257    })
 258}
 259
 260impl Entity for AssistantPanel {
 261    type Event = AssistantPanelEvent;
 262}
 263
 264impl View for AssistantPanel {
 265    fn ui_name() -> &'static str {
 266        "AssistantPanel"
 267    }
 268
 269    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
 270        let style = &theme::current(cx).assistant;
 271        if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 272            Flex::column()
 273                .with_child(
 274                    Text::new(
 275                        "Paste your OpenAI API key and press Enter to use the assistant",
 276                        style.api_key_prompt.text.clone(),
 277                    )
 278                    .aligned(),
 279                )
 280                .with_child(
 281                    ChildView::new(api_key_editor, cx)
 282                        .contained()
 283                        .with_style(style.api_key_editor.container)
 284                        .aligned(),
 285                )
 286                .contained()
 287                .with_style(style.api_key_prompt.container)
 288                .aligned()
 289                .into_any()
 290        } else {
 291            ChildView::new(&self.pane, cx).into_any()
 292        }
 293    }
 294
 295    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
 296        if cx.is_self_focused() {
 297            if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 298                cx.focus(api_key_editor);
 299            } else {
 300                cx.focus(&self.pane);
 301            }
 302        }
 303    }
 304}
 305
 306impl Panel for AssistantPanel {
 307    fn position(&self, cx: &WindowContext) -> DockPosition {
 308        match settings::get::<AssistantSettings>(cx).dock {
 309            AssistantDockPosition::Left => DockPosition::Left,
 310            AssistantDockPosition::Bottom => DockPosition::Bottom,
 311            AssistantDockPosition::Right => DockPosition::Right,
 312        }
 313    }
 314
 315    fn position_is_valid(&self, _: DockPosition) -> bool {
 316        true
 317    }
 318
 319    fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
 320        settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
 321            let dock = match position {
 322                DockPosition::Left => AssistantDockPosition::Left,
 323                DockPosition::Bottom => AssistantDockPosition::Bottom,
 324                DockPosition::Right => AssistantDockPosition::Right,
 325            };
 326            settings.dock = Some(dock);
 327        });
 328    }
 329
 330    fn size(&self, cx: &WindowContext) -> f32 {
 331        let settings = settings::get::<AssistantSettings>(cx);
 332        match self.position(cx) {
 333            DockPosition::Left | DockPosition::Right => {
 334                self.width.unwrap_or_else(|| settings.default_width)
 335            }
 336            DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
 337        }
 338    }
 339
 340    fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
 341        match self.position(cx) {
 342            DockPosition::Left | DockPosition::Right => self.width = Some(size),
 343            DockPosition::Bottom => self.height = Some(size),
 344        }
 345        cx.notify();
 346    }
 347
 348    fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
 349        matches!(event, AssistantPanelEvent::ZoomIn)
 350    }
 351
 352    fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
 353        matches!(event, AssistantPanelEvent::ZoomOut)
 354    }
 355
 356    fn is_zoomed(&self, cx: &WindowContext) -> bool {
 357        self.pane.read(cx).is_zoomed()
 358    }
 359
 360    fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
 361        self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
 362    }
 363
 364    fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
 365        if active {
 366            if self.api_key.borrow().is_none() && !self.has_read_credentials {
 367                self.has_read_credentials = true;
 368                let api_key = if let Some((_, api_key)) = cx
 369                    .platform()
 370                    .read_credentials(OPENAI_API_URL)
 371                    .log_err()
 372                    .flatten()
 373                {
 374                    String::from_utf8(api_key).log_err()
 375                } else {
 376                    None
 377                };
 378                if let Some(api_key) = api_key {
 379                    *self.api_key.borrow_mut() = Some(api_key);
 380                } else if self.api_key_editor.is_none() {
 381                    self.api_key_editor = Some(build_api_key_editor(cx));
 382                    cx.notify();
 383                }
 384            }
 385
 386            if self.pane.read(cx).items_len() == 0 {
 387                self.add_context(cx);
 388            }
 389        }
 390    }
 391
 392    fn icon_path(&self) -> &'static str {
 393        "icons/speech_bubble_12.svg"
 394    }
 395
 396    fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
 397        ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
 398    }
 399
 400    fn should_change_position_on_event(event: &Self::Event) -> bool {
 401        matches!(event, AssistantPanelEvent::DockPositionChanged)
 402    }
 403
 404    fn should_activate_on_event(_: &Self::Event) -> bool {
 405        false
 406    }
 407
 408    fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
 409        matches!(event, AssistantPanelEvent::Close)
 410    }
 411
 412    fn has_focus(&self, cx: &WindowContext) -> bool {
 413        self.pane.read(cx).has_focus()
 414            || self
 415                .api_key_editor
 416                .as_ref()
 417                .map_or(false, |editor| editor.is_focused(cx))
 418    }
 419
 420    fn is_focus_event(event: &Self::Event) -> bool {
 421        matches!(event, AssistantPanelEvent::Focus)
 422    }
 423}
 424
 425enum AssistantEvent {
 426    MessagesEdited,
 427    SummaryChanged,
 428    StreamedCompletion,
 429}
 430
 431struct Assistant {
 432    buffer: ModelHandle<Buffer>,
 433    messages: Vec<Message>,
 434    messages_metadata: HashMap<MessageId, MessageMetadata>,
 435    next_message_id: MessageId,
 436    summary: Option<String>,
 437    pending_summary: Task<Option<()>>,
 438    completion_count: usize,
 439    pending_completions: Vec<PendingCompletion>,
 440    model: String,
 441    token_count: Option<usize>,
 442    max_token_count: usize,
 443    pending_token_count: Task<Option<()>>,
 444    api_key: Rc<RefCell<Option<String>>>,
 445    _subscriptions: Vec<Subscription>,
 446}
 447
 448impl Entity for Assistant {
 449    type Event = AssistantEvent;
 450}
 451
 452impl Assistant {
 453    fn new(
 454        api_key: Rc<RefCell<Option<String>>>,
 455        language_registry: Arc<LanguageRegistry>,
 456        cx: &mut ModelContext<Self>,
 457    ) -> Self {
 458        let model = "gpt-3.5-turbo";
 459        let markdown = language_registry.language_for_name("Markdown");
 460        let buffer = cx.add_model(|cx| {
 461            let mut buffer = Buffer::new(0, "", cx);
 462            buffer.set_language_registry(language_registry);
 463            cx.spawn_weak(|buffer, mut cx| async move {
 464                let markdown = markdown.await?;
 465                let buffer = buffer
 466                    .upgrade(&cx)
 467                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
 468                buffer.update(&mut cx, |buffer, cx| {
 469                    buffer.set_language(Some(markdown), cx)
 470                });
 471                anyhow::Ok(())
 472            })
 473            .detach_and_log_err(cx);
 474            buffer
 475        });
 476
 477        let mut this = Self {
 478            messages: Default::default(),
 479            messages_metadata: Default::default(),
 480            next_message_id: Default::default(),
 481            summary: None,
 482            pending_summary: Task::ready(None),
 483            completion_count: Default::default(),
 484            pending_completions: Default::default(),
 485            token_count: None,
 486            max_token_count: tiktoken_rs::model::get_context_size(model),
 487            pending_token_count: Task::ready(None),
 488            model: model.into(),
 489            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
 490            api_key,
 491            buffer,
 492        };
 493        let message = Message {
 494            id: MessageId(post_inc(&mut this.next_message_id.0)),
 495            start: language::Anchor::MIN,
 496        };
 497        this.messages.push(message.clone());
 498        this.messages_metadata.insert(
 499            message.id,
 500            MessageMetadata {
 501                role: Role::User,
 502                sent_at: Local::now(),
 503                error: None,
 504            },
 505        );
 506
 507        this.count_remaining_tokens(cx);
 508        this
 509    }
 510
 511    fn handle_buffer_event(
 512        &mut self,
 513        _: ModelHandle<Buffer>,
 514        event: &language::Event,
 515        cx: &mut ModelContext<Self>,
 516    ) {
 517        match event {
 518            language::Event::Edited => {
 519                self.count_remaining_tokens(cx);
 520                cx.emit(AssistantEvent::MessagesEdited);
 521            }
 522            _ => {}
 523        }
 524    }
 525
 526    fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
 527        let messages = self
 528            .open_ai_request_messages(cx)
 529            .into_iter()
 530            .filter_map(|message| {
 531                Some(tiktoken_rs::ChatCompletionRequestMessage {
 532                    role: match message.role {
 533                        Role::User => "user".into(),
 534                        Role::Assistant => "assistant".into(),
 535                        Role::System => "system".into(),
 536                    },
 537                    content: message.content,
 538                    name: None,
 539                })
 540            })
 541            .collect::<Vec<_>>();
 542        let model = self.model.clone();
 543        self.pending_token_count = cx.spawn_weak(|this, mut cx| {
 544            async move {
 545                cx.background().timer(Duration::from_millis(200)).await;
 546                let token_count = cx
 547                    .background()
 548                    .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
 549                    .await?;
 550
 551                this.upgrade(&cx)
 552                    .ok_or_else(|| anyhow!("assistant was dropped"))?
 553                    .update(&mut cx, |this, cx| {
 554                        this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
 555                        this.token_count = Some(token_count);
 556                        cx.notify()
 557                    });
 558                anyhow::Ok(())
 559            }
 560            .log_err()
 561        });
 562    }
 563
 564    fn remaining_tokens(&self) -> Option<isize> {
 565        Some(self.max_token_count as isize - self.token_count? as isize)
 566    }
 567
 568    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
 569        self.model = model;
 570        self.count_remaining_tokens(cx);
 571        cx.notify();
 572    }
 573
 574    fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
 575        let request = OpenAIRequest {
 576            model: self.model.clone(),
 577            messages: self.open_ai_request_messages(cx),
 578            stream: true,
 579        };
 580
 581        let api_key = self.api_key.borrow().clone()?;
 582        let stream = stream_completion(api_key, cx.background().clone(), request);
 583        let assistant_message =
 584            self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
 585        let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
 586        let task = cx.spawn_weak({
 587            |this, mut cx| async move {
 588                let assistant_message_id = assistant_message.id;
 589                let stream_completion = async {
 590                    let mut messages = stream.await?;
 591
 592                    while let Some(message) = messages.next().await {
 593                        let mut message = message?;
 594                        if let Some(choice) = message.choices.pop() {
 595                            this.upgrade(&cx)
 596                                .ok_or_else(|| anyhow!("assistant was dropped"))?
 597                                .update(&mut cx, |this, cx| {
 598                                    let text: Arc<str> = choice.delta.content?.into();
 599                                    let message_ix = this
 600                                        .messages
 601                                        .iter()
 602                                        .position(|message| message.id == assistant_message_id)?;
 603                                    this.buffer.update(cx, |buffer, cx| {
 604                                        let offset = if message_ix + 1 == this.messages.len() {
 605                                            buffer.len()
 606                                        } else {
 607                                            this.messages[message_ix + 1]
 608                                                .start
 609                                                .to_offset(buffer)
 610                                                .saturating_sub(1)
 611                                        };
 612                                        buffer.edit([(offset..offset, text)], None, cx);
 613                                    });
 614
 615                                    Some(())
 616                                });
 617                        }
 618                    }
 619
 620                    this.upgrade(&cx)
 621                        .ok_or_else(|| anyhow!("assistant was dropped"))?
 622                        .update(&mut cx, |this, cx| {
 623                            this.pending_completions
 624                                .retain(|completion| completion.id != this.completion_count);
 625                            this.summarize(cx);
 626                        });
 627
 628                    anyhow::Ok(())
 629                };
 630
 631                let result = stream_completion.await;
 632                if let Some(this) = this.upgrade(&cx) {
 633                    this.update(&mut cx, |this, cx| {
 634                        if let Err(error) = result {
 635                            if let Some(metadata) =
 636                                this.messages_metadata.get_mut(&assistant_message.id)
 637                            {
 638                                metadata.error = Some(error.to_string().trim().into());
 639                                cx.notify();
 640                            }
 641                        }
 642                    });
 643                }
 644            }
 645        });
 646
 647        self.pending_completions.push(PendingCompletion {
 648            id: post_inc(&mut self.completion_count),
 649            _task: task,
 650        });
 651        Some((assistant_message, user_message))
 652    }
 653
 654    fn cancel_last_assist(&mut self) -> bool {
 655        self.pending_completions.pop().is_some()
 656    }
 657
 658    fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
 659        if let Some(metadata) = self.messages_metadata.get_mut(&id) {
 660            metadata.role.cycle();
 661            cx.emit(AssistantEvent::MessagesEdited);
 662            cx.notify();
 663        }
 664    }
 665
 666    fn insert_message_after(
 667        &mut self,
 668        message_id: MessageId,
 669        role: Role,
 670        cx: &mut ModelContext<Self>,
 671    ) -> Option<Message> {
 672        if let Some(prev_message_ix) = self
 673            .messages
 674            .iter()
 675            .position(|message| message.id == message_id)
 676        {
 677            let start = self.buffer.update(cx, |buffer, cx| {
 678                let offset = self.messages[prev_message_ix + 1..]
 679                    .iter()
 680                    .find(|message| message.start.is_valid(buffer))
 681                    .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
 682                buffer.edit([(offset..offset, "\n")], None, cx);
 683                buffer.anchor_before(offset + 1)
 684            });
 685            let message = Message {
 686                id: MessageId(post_inc(&mut self.next_message_id.0)),
 687                start,
 688            };
 689            self.messages.insert(prev_message_ix + 1, message.clone());
 690            self.messages_metadata.insert(
 691                message.id,
 692                MessageMetadata {
 693                    role,
 694                    sent_at: Local::now(),
 695                    error: None,
 696                },
 697            );
 698            Some(message)
 699        } else {
 700            None
 701        }
 702    }
 703
 704    fn summarize(&mut self, cx: &mut ModelContext<Self>) {
 705        if self.messages.len() >= 2 && self.summary.is_none() {
 706            let api_key = self.api_key.borrow().clone();
 707            if let Some(api_key) = api_key {
 708                let mut messages = self.open_ai_request_messages(cx);
 709                messages.truncate(2);
 710                messages.push(RequestMessage {
 711                    role: Role::User,
 712                    content: "Summarize the conversation into a short title without punctuation"
 713                        .into(),
 714                });
 715                let request = OpenAIRequest {
 716                    model: self.model.clone(),
 717                    messages,
 718                    stream: true,
 719                };
 720
 721                let stream = stream_completion(api_key, cx.background().clone(), request);
 722                self.pending_summary = cx.spawn(|this, mut cx| {
 723                    async move {
 724                        let mut messages = stream.await?;
 725
 726                        while let Some(message) = messages.next().await {
 727                            let mut message = message?;
 728                            if let Some(choice) = message.choices.pop() {
 729                                let text = choice.delta.content.unwrap_or_default();
 730                                this.update(&mut cx, |this, cx| {
 731                                    this.summary.get_or_insert(String::new()).push_str(&text);
 732                                    cx.emit(AssistantEvent::SummaryChanged);
 733                                });
 734                            }
 735                        }
 736
 737                        anyhow::Ok(())
 738                    }
 739                    .log_err()
 740                });
 741            }
 742        }
 743    }
 744
 745    fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
 746        let buffer = self.buffer.read(cx);
 747        self.messages(cx)
 748            .map(|(message, metadata, range)| RequestMessage {
 749                role: metadata.role,
 750                content: buffer.text_for_range(range).collect(),
 751            })
 752            .collect()
 753    }
 754
 755    fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option<MessageId> {
 756        Some(
 757            self.messages(cx)
 758                .find(|(_, _, range)| range.contains(&offset))
 759                .map(|(message, _, _)| message)
 760                .or(self.messages.last())?
 761                .id,
 762        )
 763    }
 764
 765    fn messages<'a>(
 766        &'a self,
 767        cx: &'a AppContext,
 768    ) -> impl 'a + Iterator<Item = (&Message, &MessageMetadata, Range<usize>)> {
 769        let buffer = self.buffer.read(cx);
 770        let mut messages = self.messages.iter().peekable();
 771        iter::from_fn(move || {
 772            while let Some(message) = messages.next() {
 773                let metadata = self.messages_metadata.get(&message.id)?;
 774                let message_start = message.start.to_offset(buffer);
 775                let mut message_end = None;
 776                while let Some(next_message) = messages.peek() {
 777                    if next_message.start.is_valid(buffer) {
 778                        message_end = Some(next_message.start);
 779                        break;
 780                    } else {
 781                        messages.next();
 782                    }
 783                }
 784                let message_end = message_end
 785                    .unwrap_or(language::Anchor::MAX)
 786                    .to_offset(buffer);
 787                return Some((message, metadata, message_start..message_end));
 788            }
 789            None
 790        })
 791    }
 792}
 793
/// A completion request that is currently in flight.
struct PendingCompletion {
    /// Identifier of this completion.
    /// NOTE(review): presumably used by `Assistant` to match cancellation/finish to the
    /// right request — confirm against its usage outside this view.
    id: usize,
    /// Task driving the completion. Stored but never read here; presumably held so the
    /// task is not dropped (and thereby cancelled) early — confirm gpui `Task` semantics.
    _task: Task<()>,
}
 798
/// Events emitted by [`AssistantEditor`] to its subscribers.
enum AssistantEditorEvent {
    /// The tab's title may have changed. Emitted when the assistant's summary, which
    /// `AssistantEditor::title` displays, changes.
    TabContentChanged,
}
 802
/// A view that presents an [`Assistant`] conversation inside an [`Editor`].
struct AssistantEditor {
    /// The conversation model; it owns the buffer that `editor` displays.
    assistant: ModelHandle<Assistant>,
    /// Editor over the assistant's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the per-message header blocks currently inserted into `editor`, so they
    /// can be removed before being rebuilt.
    blocks: HashSet<BlockId>,
    /// Buffer anchor (plus fractional row / horizontal offset) at the bottom edge of the
    /// viewport, used to keep the view pinned while a completion streams in.
    scroll_bottom: ScrollAnchor,
    /// Subscriptions created in `new`; stored but never read, presumably so they stay
    /// alive for as long as this view does.
    _subscriptions: Vec<Subscription>,
}
 810
 811impl AssistantEditor {
 812    fn new(
 813        api_key: Rc<RefCell<Option<String>>>,
 814        language_registry: Arc<LanguageRegistry>,
 815        cx: &mut ViewContext<Self>,
 816    ) -> Self {
 817        let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
 818        let editor = cx.add_view(|cx| {
 819            let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
 820            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
 821            editor.set_show_gutter(false, cx);
 822            editor
 823        });
 824
 825        let _subscriptions = vec![
 826            cx.observe(&assistant, |_, _, cx| cx.notify()),
 827            cx.subscribe(&assistant, Self::handle_assistant_event),
 828            cx.subscribe(&editor, Self::handle_editor_event),
 829        ];
 830
 831        Self {
 832            assistant,
 833            editor,
 834            blocks: Default::default(),
 835            scroll_bottom: ScrollAnchor {
 836                offset: Default::default(),
 837                anchor: Anchor::max(),
 838            },
 839            _subscriptions,
 840        }
 841    }
 842
 843    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
 844        let user_message = self.assistant.update(cx, |assistant, cx| {
 845            let editor = self.editor.read(cx);
 846            let newest_selection = editor
 847                .selections
 848                .newest_anchor()
 849                .head()
 850                .to_offset(&editor.buffer().read(cx).snapshot(cx));
 851            let message_id = assistant.message_id_for_offset(newest_selection, cx)?;
 852            let metadata = assistant.messages_metadata.get(&message_id)?;
 853            let user_message = if metadata.role == Role::User {
 854                let (_, user_message) = assistant.assist(cx)?;
 855                user_message
 856            } else {
 857                let user_message = assistant.insert_message_after(message_id, Role::User, cx)?;
 858                user_message
 859            };
 860            Some(user_message)
 861        });
 862
 863        if let Some(user_message) = user_message {
 864            let cursor = user_message
 865                .start
 866                .to_offset(&self.assistant.read(cx).buffer.read(cx));
 867            self.editor.update(cx, |editor, cx| {
 868                editor.change_selections(
 869                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
 870                    cx,
 871                    |selections| selections.select_ranges([cursor..cursor]),
 872                );
 873            });
 874            self.update_scroll_bottom(cx);
 875        }
 876    }
 877
 878    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
 879        if !self
 880            .assistant
 881            .update(cx, |assistant, _| assistant.cancel_last_assist())
 882        {
 883            cx.propagate_action();
 884        }
 885    }
 886
 887    fn handle_assistant_event(
 888        &mut self,
 889        _: ModelHandle<Assistant>,
 890        event: &AssistantEvent,
 891        cx: &mut ViewContext<Self>,
 892    ) {
 893        match event {
 894            AssistantEvent::MessagesEdited => {
 895                self.editor.update(cx, |editor, cx| {
 896                    let buffer = editor.buffer().read(cx).snapshot(cx);
 897                    let excerpt_id = *buffer.as_singleton().unwrap().0;
 898                    let old_blocks = std::mem::take(&mut self.blocks);
 899                    let new_blocks =
 900                        self.assistant
 901                            .read(cx)
 902                            .messages(cx)
 903                            .map(|(message, metadata, _)| BlockProperties {
 904                                position: buffer.anchor_in_excerpt(excerpt_id, message.start),
 905                                height: 2,
 906                                style: BlockStyle::Sticky,
 907                                render: Arc::new({
 908                                    let assistant = self.assistant.clone();
 909                                    let metadata = metadata.clone();
 910                                    let message = message.clone();
 911                                    move |cx| {
 912                                        enum Sender {}
 913                                        enum ErrorTooltip {}
 914
 915                                        let theme = theme::current(cx);
 916                                        let style = &theme.assistant;
 917                                        let message_id = message.id;
 918                                        let sender = MouseEventHandler::<Sender, _>::new(
 919                                            message_id.0,
 920                                            cx,
 921                                            |state, _| match metadata.role {
 922                                                Role::User => {
 923                                                    let style =
 924                                                        style.user_sender.style_for(state, false);
 925                                                    Label::new("You", style.text.clone())
 926                                                        .contained()
 927                                                        .with_style(style.container)
 928                                                }
 929                                                Role::Assistant => {
 930                                                    let style = style
 931                                                        .assistant_sender
 932                                                        .style_for(state, false);
 933                                                    Label::new("Assistant", style.text.clone())
 934                                                        .contained()
 935                                                        .with_style(style.container)
 936                                                }
 937                                                Role::System => {
 938                                                    let style =
 939                                                        style.system_sender.style_for(state, false);
 940                                                    Label::new("System", style.text.clone())
 941                                                        .contained()
 942                                                        .with_style(style.container)
 943                                                }
 944                                            },
 945                                        )
 946                                        .with_cursor_style(CursorStyle::PointingHand)
 947                                        .on_down(MouseButton::Left, {
 948                                            let assistant = assistant.clone();
 949                                            move |_, _, cx| {
 950                                                assistant.update(cx, |assistant, cx| {
 951                                                    assistant.cycle_message_role(message_id, cx)
 952                                                })
 953                                            }
 954                                        });
 955
 956                                        Flex::row()
 957                                            .with_child(sender.aligned())
 958                                            .with_child(
 959                                                Label::new(
 960                                                    metadata.sent_at.format("%I:%M%P").to_string(),
 961                                                    style.sent_at.text.clone(),
 962                                                )
 963                                                .contained()
 964                                                .with_style(style.sent_at.container)
 965                                                .aligned(),
 966                                            )
 967                                            .with_children(metadata.error.clone().map(|error| {
 968                                                Svg::new("icons/circle_x_mark_12.svg")
 969                                                    .with_color(style.error_icon.color)
 970                                                    .constrained()
 971                                                    .with_width(style.error_icon.width)
 972                                                    .contained()
 973                                                    .with_style(style.error_icon.container)
 974                                                    .with_tooltip::<ErrorTooltip>(
 975                                                        message_id.0,
 976                                                        error,
 977                                                        None,
 978                                                        theme.tooltip.clone(),
 979                                                        cx,
 980                                                    )
 981                                                    .aligned()
 982                                            }))
 983                                            .aligned()
 984                                            .left()
 985                                            .contained()
 986                                            .with_style(style.header)
 987                                            .into_any()
 988                                    }
 989                                }),
 990                                disposition: BlockDisposition::Above,
 991                            })
 992                            .collect::<Vec<_>>();
 993
 994                    editor.remove_blocks(old_blocks, cx);
 995                    let ids = editor.insert_blocks(new_blocks, cx);
 996                    self.blocks = HashSet::from_iter(ids);
 997                });
 998            }
 999
1000            AssistantEvent::SummaryChanged => {
1001                cx.emit(AssistantEditorEvent::TabContentChanged);
1002            }
1003
1004            AssistantEvent::StreamedCompletion => {
1005                self.editor.update(cx, |editor, cx| {
1006                    let snapshot = editor.snapshot(cx);
1007                    let scroll_bottom_row = self
1008                        .scroll_bottom
1009                        .anchor
1010                        .to_display_point(&snapshot.display_snapshot)
1011                        .row();
1012
1013                    let scroll_bottom = scroll_bottom_row as f32 + self.scroll_bottom.offset.y();
1014                    let visible_line_count = editor.visible_line_count().unwrap_or(0.);
1015                    let scroll_top = scroll_bottom - visible_line_count;
1016                    editor
1017                        .set_scroll_position(vec2f(self.scroll_bottom.offset.x(), scroll_top), cx);
1018                });
1019            }
1020        }
1021    }
1022
1023    fn handle_editor_event(
1024        &mut self,
1025        _: ViewHandle<Editor>,
1026        event: &editor::Event,
1027        cx: &mut ViewContext<Self>,
1028    ) {
1029        match event {
1030            editor::Event::ScrollPositionChanged { .. } => self.update_scroll_bottom(cx),
1031            _ => {}
1032        }
1033    }
1034
1035    fn update_scroll_bottom(&mut self, cx: &mut ViewContext<Self>) {
1036        self.editor.update(cx, |editor, cx| {
1037            let snapshot = editor.snapshot(cx);
1038            let scroll_position = editor
1039                .scroll_manager
1040                .anchor()
1041                .scroll_position(&snapshot.display_snapshot);
1042            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1043            let scroll_bottom_point = cmp::min(
1044                DisplayPoint::new(scroll_bottom.floor() as u32, 0),
1045                snapshot.display_snapshot.max_point(),
1046            );
1047            let scroll_bottom_anchor = snapshot
1048                .buffer_snapshot
1049                .anchor_after(scroll_bottom_point.to_point(&snapshot.display_snapshot));
1050            let scroll_bottom_offset = vec2f(
1051                scroll_position.x(),
1052                scroll_bottom - scroll_bottom_point.row() as f32,
1053            );
1054            self.scroll_bottom = ScrollAnchor {
1055                anchor: scroll_bottom_anchor,
1056                offset: scroll_bottom_offset,
1057            };
1058        });
1059    }
1060
1061    fn quote_selection(
1062        workspace: &mut Workspace,
1063        _: &QuoteSelection,
1064        cx: &mut ViewContext<Workspace>,
1065    ) {
1066        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1067            return;
1068        };
1069        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1070            return;
1071        };
1072
1073        let text = editor.read_with(cx, |editor, cx| {
1074            let range = editor.selections.newest::<usize>(cx).range();
1075            let buffer = editor.buffer().read(cx).snapshot(cx);
1076            let start_language = buffer.language_at(range.start);
1077            let end_language = buffer.language_at(range.end);
1078            let language_name = if start_language == end_language {
1079                start_language.map(|language| language.name())
1080            } else {
1081                None
1082            };
1083            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1084
1085            let selected_text = buffer.text_for_range(range).collect::<String>();
1086            if selected_text.is_empty() {
1087                None
1088            } else {
1089                Some(if language_name == "markdown" {
1090                    selected_text
1091                        .lines()
1092                        .map(|line| format!("> {}", line))
1093                        .collect::<Vec<_>>()
1094                        .join("\n")
1095                } else {
1096                    format!("```{language_name}\n{selected_text}\n```")
1097                })
1098            }
1099        });
1100
1101        // Activate the panel
1102        if !panel.read(cx).has_focus(cx) {
1103            workspace.toggle_panel_focus::<AssistantPanel>(cx);
1104        }
1105
1106        if let Some(text) = text {
1107            panel.update(cx, |panel, cx| {
1108                if let Some(assistant) = panel
1109                    .pane
1110                    .read(cx)
1111                    .active_item()
1112                    .and_then(|item| item.downcast::<AssistantEditor>())
1113                    .ok_or_else(|| anyhow!("no active context"))
1114                    .log_err()
1115                {
1116                    assistant.update(cx, |assistant, cx| {
1117                        assistant
1118                            .editor
1119                            .update(cx, |editor, cx| editor.insert(&text, cx))
1120                    });
1121                }
1122            });
1123        }
1124    }
1125
1126    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1127        let editor = self.editor.read(cx);
1128        let assistant = self.assistant.read(cx);
1129        if editor.selections.count() == 1 {
1130            let selection = editor.selections.newest::<usize>(cx);
1131            let mut offset = 0;
1132            let mut copied_text = String::new();
1133            let mut spanned_messages = 0;
1134            for message in &assistant.messages {
1135                todo!();
1136                // let message_range = offset..offset + message.content.read(cx).len() + 1;
1137                let message_range = offset..offset + 1;
1138
1139                if message_range.start >= selection.range().end {
1140                    break;
1141                } else if message_range.end >= selection.range().start {
1142                    let range = cmp::max(message_range.start, selection.range().start)
1143                        ..cmp::min(message_range.end, selection.range().end);
1144                    if !range.is_empty() {
1145                        if let Some(metadata) = assistant.messages_metadata.get(&message.id) {
1146                            spanned_messages += 1;
1147                            write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
1148                            for chunk in assistant.buffer.read(cx).text_for_range(range) {
1149                                copied_text.push_str(&chunk);
1150                            }
1151                            copied_text.push('\n');
1152                        }
1153                    }
1154                }
1155
1156                offset = message_range.end;
1157            }
1158
1159            if spanned_messages > 1 {
1160                cx.platform()
1161                    .write_to_clipboard(ClipboardItem::new(copied_text));
1162                return;
1163            }
1164        }
1165
1166        cx.propagate_action();
1167    }
1168
1169    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1170        self.assistant.update(cx, |assistant, cx| {
1171            let new_model = match assistant.model.as_str() {
1172                "gpt-4" => "gpt-3.5-turbo",
1173                _ => "gpt-4",
1174            };
1175            assistant.set_model(new_model.into(), cx);
1176        });
1177    }
1178
1179    fn title(&self, cx: &AppContext) -> String {
1180        self.assistant
1181            .read(cx)
1182            .summary
1183            .clone()
1184            .unwrap_or_else(|| "New Context".into())
1185    }
1186}
1187
/// `AssistantEditor` emits [`AssistantEditorEvent`]s to its subscribers.
impl Entity for AssistantEditor {
    type Event = AssistantEditorEvent;
}
1191
1192impl View for AssistantEditor {
1193    fn ui_name() -> &'static str {
1194        "AssistantEditor"
1195    }
1196
1197    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
1198        enum Model {}
1199        let theme = &theme::current(cx).assistant;
1200        let assistant = &self.assistant.read(cx);
1201        let model = assistant.model.clone();
1202        let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
1203            let remaining_tokens_style = if remaining_tokens <= 0 {
1204                &theme.no_remaining_tokens
1205            } else {
1206                &theme.remaining_tokens
1207            };
1208            Label::new(
1209                remaining_tokens.to_string(),
1210                remaining_tokens_style.text.clone(),
1211            )
1212            .contained()
1213            .with_style(remaining_tokens_style.container)
1214        });
1215
1216        Stack::new()
1217            .with_child(
1218                ChildView::new(&self.editor, cx)
1219                    .contained()
1220                    .with_style(theme.container),
1221            )
1222            .with_child(
1223                Flex::row()
1224                    .with_child(
1225                        MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
1226                            let style = theme.model.style_for(state, false);
1227                            Label::new(model, style.text.clone())
1228                                .contained()
1229                                .with_style(style.container)
1230                        })
1231                        .with_cursor_style(CursorStyle::PointingHand)
1232                        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
1233                    )
1234                    .with_children(remaining_tokens)
1235                    .contained()
1236                    .with_style(theme.model_info_container)
1237                    .aligned()
1238                    .top()
1239                    .right(),
1240            )
1241            .into_any()
1242    }
1243
1244    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
1245        if cx.is_self_focused() {
1246            cx.focus(&self.editor);
1247        }
1248    }
1249}
1250
1251impl Item for AssistantEditor {
1252    fn tab_content<V: View>(
1253        &self,
1254        _: Option<usize>,
1255        style: &theme::Tab,
1256        cx: &gpui::AppContext,
1257    ) -> AnyElement<V> {
1258        let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1259        Label::new(title, style.label.clone()).into_any()
1260    }
1261
1262    fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1263        Some(self.title(cx).into())
1264    }
1265
1266    fn as_searchable(
1267        &self,
1268        _: &ViewHandle<Self>,
1269    ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1270        Some(Box::new(self.editor.clone()))
1271    }
1272}
1273
/// Identifier of a message within an `Assistant` conversation; also used as the key
/// into its message metadata and as the mouse-region / tooltip id of the message header.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1276
/// A message in the conversation buffer, identified only by where it starts.
///
/// Its extent runs to the start of the next message whose anchor is still valid, or to
/// the end of the buffer (see `Assistant::messages`).
#[derive(Clone, Debug)]
struct Message {
    id: MessageId,
    /// Anchor at the beginning of the message text. When the text holding it is deleted,
    /// the anchor becomes invalid and the message is merged into its predecessor.
    start: language::Anchor,
}
1282
/// Per-message data kept alongside the buffer text.
#[derive(Clone, Debug)]
struct MessageMetadata {
    /// Author of the message. Shown in the header (clicking it cycles the role) and sent
    /// as the role of the corresponding API request message.
    role: Role,
    /// Local time shown in the message header.
    sent_at: DateTime<Local>,
    /// Error shown as an icon with this text as tooltip in the header, when present.
    error: Option<String>,
}
1289
1290async fn stream_completion(
1291    api_key: String,
1292    executor: Arc<Background>,
1293    mut request: OpenAIRequest,
1294) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1295    request.stream = true;
1296
1297    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1298
1299    let json_data = serde_json::to_string(&request)?;
1300    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1301        .header("Content-Type", "application/json")
1302        .header("Authorization", format!("Bearer {}", api_key))
1303        .body(json_data)?
1304        .send_async()
1305        .await?;
1306
1307    let status = response.status();
1308    if status == StatusCode::OK {
1309        executor
1310            .spawn(async move {
1311                let mut lines = BufReader::new(response.body_mut()).lines();
1312
1313                fn parse_line(
1314                    line: Result<String, io::Error>,
1315                ) -> Result<Option<OpenAIResponseStreamEvent>> {
1316                    if let Some(data) = line?.strip_prefix("data: ") {
1317                        let event = serde_json::from_str(&data)?;
1318                        Ok(Some(event))
1319                    } else {
1320                        Ok(None)
1321                    }
1322                }
1323
1324                while let Some(line) = lines.next().await {
1325                    if let Some(event) = parse_line(line).transpose() {
1326                        let done = event.as_ref().map_or(false, |event| {
1327                            event
1328                                .choices
1329                                .last()
1330                                .map_or(false, |choice| choice.finish_reason.is_some())
1331                        });
1332                        if tx.unbounded_send(event).is_err() {
1333                            break;
1334                        }
1335
1336                        if done {
1337                            break;
1338                        }
1339                    }
1340                }
1341
1342                anyhow::Ok(())
1343            })
1344            .detach();
1345
1346        Ok(rx)
1347    } else {
1348        let mut body = String::new();
1349        response.body_mut().read_to_string(&mut body).await?;
1350
1351        #[derive(Deserialize)]
1352        struct OpenAIResponse {
1353            error: OpenAIError,
1354        }
1355
1356        #[derive(Deserialize)]
1357        struct OpenAIError {
1358            message: String,
1359        }
1360
1361        match serde_json::from_str::<OpenAIResponse>(&body) {
1362            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1363                "Failed to connect to OpenAI API: {}",
1364                response.error.message,
1365            )),
1366
1367            _ => Err(anyhow!(
1368                "Failed to connect to OpenAI API: {} {}",
1369                response.status(),
1370                body,
1371            )),
1372        }
1373    }
1374}
1375
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::AppContext;

    /// Exercises message insertion, typing at message boundaries, merging messages by
    /// deleting across their boundaries, and undo/redo of such merges. Each step checks
    /// every visible message's id, role and byte range.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        // A new assistant starts with a single, empty user message.
        let message_1 = assistant.read(cx).messages[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        let message_2 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::Assistant, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text typed at a message's start offset lands in that message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message before message_3.
        let message_4 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::System, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// Snapshot of the messages yielded by `Assistant::messages`, as
    /// `(id, role, byte range)` tuples.
    fn messages(
        assistant: &ModelHandle<Assistant>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        assistant
            .read(cx)
            .messages(cx)
            .map(|(message, metadata, range)| (message.id, metadata.role, range))
            .collect()
    }
}