assistant.rs

   1use crate::{
   2    assistant_settings::{AssistantDockPosition, AssistantSettings},
   3    OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role,
   4};
   5use anyhow::{anyhow, Result};
   6use chrono::{DateTime, Local};
   7use collections::{HashMap, HashSet};
   8use editor::{
   9    display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
  10    scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
  11    Anchor, Editor, ToOffset,
  12};
  13use fs::Fs;
  14use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  15use gpui::{
  16    actions,
  17    elements::*,
  18    executor::Background,
  19    geometry::vector::{vec2f, Vector2F},
  20    platform::{CursorStyle, MouseButton},
  21    Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
  22    Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
  23};
  24use isahc::{http::StatusCode, Request, RequestExt};
  25use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
  26use serde::Deserialize;
  27use settings::SettingsStore;
  28use std::{
  29    borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
  30    time::Duration,
  31};
  32use util::{channel::ReleaseChannel, post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
  33use workspace::{
  34    dock::{DockPosition, Panel},
  35    item::Item,
  36    pane, Pane, Workspace,
  37};
  38
  39const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
  40
// Actions declared in the `assistant` namespace. `init` registers the
// handlers; on the stable release channel it also adds this namespace to the
// command palette filter.
actions!(
    assistant,
    [
        NewContext,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey
    ]
);
  53
  54pub fn init(cx: &mut AppContext) {
  55    if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
  56        cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
  57            filter.filtered_namespaces.insert("assistant");
  58        });
  59    }
  60
  61    settings::register::<AssistantSettings>(cx);
  62    cx.add_action(
  63        |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
  64            if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
  65                this.update(cx, |this, cx| this.add_context(cx))
  66            }
  67
  68            workspace.focus_panel::<AssistantPanel>(cx);
  69        },
  70    );
  71    cx.add_action(AssistantEditor::assist);
  72    cx.capture_action(AssistantEditor::cancel_last_assist);
  73    cx.add_action(AssistantEditor::quote_selection);
  74    cx.capture_action(AssistantEditor::copy);
  75    cx.capture_action(AssistantEditor::split);
  76    cx.capture_action(AssistantEditor::cycle_message_role);
  77    cx.add_action(AssistantPanel::save_api_key);
  78    cx.add_action(AssistantPanel::reset_api_key);
  79    cx.add_action(
  80        |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
  81            workspace.toggle_panel_focus::<AssistantPanel>(cx);
  82        },
  83    );
  84}
  85
  86pub enum AssistantPanelEvent {
  87    ZoomIn,
  88    ZoomOut,
  89    Focus,
  90    Close,
  91    DockPositionChanged,
  92}
  93
/// Dock panel hosting the assistant's contexts inside a single pane.
pub struct AssistantPanel {
    // Size set through `set_size` while docked left/right; `None` falls back
    // to the default width from settings.
    width: Option<f32>,
    // Size set through `set_size` while docked at the bottom; `None` falls
    // back to the default height from settings.
    height: Option<f32>,
    // Pane that holds one item per assistant editor.
    pane: ViewHandle<Pane>,
    // OpenAI API key, shared with every editor created by `add_context`.
    api_key: Rc<RefCell<Option<String>>>,
    // Present while the panel is prompting for an API key; rendered instead of
    // the pane.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Set the first time `set_active` attempts to read stored credentials so
    // the credential store is only queried once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Keeps the pane, settings and per-editor subscriptions alive.
    subscriptions: Vec<Subscription>,
}
 105
 106impl AssistantPanel {
 107    pub fn load(
 108        workspace: WeakViewHandle<Workspace>,
 109        cx: AsyncAppContext,
 110    ) -> Task<Result<ViewHandle<Self>>> {
 111        cx.spawn(|mut cx| async move {
 112            // TODO: deserialize state.
 113            workspace.update(&mut cx, |workspace, cx| {
 114                cx.add_view::<Self, _>(|cx| {
 115                    let weak_self = cx.weak_handle();
 116                    let pane = cx.add_view(|cx| {
 117                        let mut pane = Pane::new(
 118                            workspace.weak_handle(),
 119                            workspace.project().clone(),
 120                            workspace.app_state().background_actions,
 121                            Default::default(),
 122                            cx,
 123                        );
 124                        pane.set_can_split(false, cx);
 125                        pane.set_can_navigate(false, cx);
 126                        pane.on_can_drop(move |_, _| false);
 127                        pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
 128                            let weak_self = weak_self.clone();
 129                            Flex::row()
 130                                .with_child(Pane::render_tab_bar_button(
 131                                    0,
 132                                    "icons/plus_12.svg",
 133                                    false,
 134                                    Some(("New Context".into(), Some(Box::new(NewContext)))),
 135                                    cx,
 136                                    move |_, cx| {
 137                                        let weak_self = weak_self.clone();
 138                                        cx.window_context().defer(move |cx| {
 139                                            if let Some(this) = weak_self.upgrade(cx) {
 140                                                this.update(cx, |this, cx| this.add_context(cx));
 141                                            }
 142                                        })
 143                                    },
 144                                    None,
 145                                ))
 146                                .with_child(Pane::render_tab_bar_button(
 147                                    1,
 148                                    if pane.is_zoomed() {
 149                                        "icons/minimize_8.svg"
 150                                    } else {
 151                                        "icons/maximize_8.svg"
 152                                    },
 153                                    pane.is_zoomed(),
 154                                    Some((
 155                                        "Toggle Zoom".into(),
 156                                        Some(Box::new(workspace::ToggleZoom)),
 157                                    )),
 158                                    cx,
 159                                    move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
 160                                    None,
 161                                ))
 162                                .into_any()
 163                        });
 164                        let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
 165                        pane.toolbar()
 166                            .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
 167                        pane
 168                    });
 169
 170                    let mut this = Self {
 171                        pane,
 172                        api_key: Rc::new(RefCell::new(None)),
 173                        api_key_editor: None,
 174                        has_read_credentials: false,
 175                        languages: workspace.app_state().languages.clone(),
 176                        fs: workspace.app_state().fs.clone(),
 177                        width: None,
 178                        height: None,
 179                        subscriptions: Default::default(),
 180                    };
 181
 182                    let mut old_dock_position = this.position(cx);
 183                    this.subscriptions = vec![
 184                        cx.observe(&this.pane, |_, _, cx| cx.notify()),
 185                        cx.subscribe(&this.pane, Self::handle_pane_event),
 186                        cx.observe_global::<SettingsStore, _>(move |this, cx| {
 187                            let new_dock_position = this.position(cx);
 188                            if new_dock_position != old_dock_position {
 189                                old_dock_position = new_dock_position;
 190                                cx.emit(AssistantPanelEvent::DockPositionChanged);
 191                            }
 192                        }),
 193                    ];
 194
 195                    this
 196                })
 197            })
 198        })
 199    }
 200
 201    fn handle_pane_event(
 202        &mut self,
 203        _pane: ViewHandle<Pane>,
 204        event: &pane::Event,
 205        cx: &mut ViewContext<Self>,
 206    ) {
 207        match event {
 208            pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
 209            pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
 210            pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
 211            pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
 212            _ => {}
 213        }
 214    }
 215
 216    fn add_context(&mut self, cx: &mut ViewContext<Self>) {
 217        let focus = self.has_focus(cx);
 218        let editor = cx
 219            .add_view(|cx| AssistantEditor::new(self.api_key.clone(), self.languages.clone(), cx));
 220        self.subscriptions
 221            .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
 222        self.pane.update(cx, |pane, cx| {
 223            pane.add_item(Box::new(editor), true, focus, None, cx)
 224        });
 225    }
 226
 227    fn handle_assistant_editor_event(
 228        &mut self,
 229        _: ViewHandle<AssistantEditor>,
 230        event: &AssistantEditorEvent,
 231        cx: &mut ViewContext<Self>,
 232    ) {
 233        match event {
 234            AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
 235        }
 236    }
 237
 238    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
 239        if let Some(api_key) = self
 240            .api_key_editor
 241            .as_ref()
 242            .map(|editor| editor.read(cx).text(cx))
 243        {
 244            if !api_key.is_empty() {
 245                cx.platform()
 246                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
 247                    .log_err();
 248                *self.api_key.borrow_mut() = Some(api_key);
 249                self.api_key_editor.take();
 250                cx.focus_self();
 251                cx.notify();
 252            }
 253        } else {
 254            cx.propagate_action();
 255        }
 256    }
 257
 258    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
 259        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
 260        self.api_key.take();
 261        self.api_key_editor = Some(build_api_key_editor(cx));
 262        cx.focus_self();
 263        cx.notify();
 264    }
 265}
 266
 267fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
 268    cx.add_view(|cx| {
 269        let mut editor = Editor::single_line(
 270            Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
 271            cx,
 272        );
 273        editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
 274        editor
 275    })
 276}
 277
// The panel emits `AssistantPanelEvent`s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
 281
 282impl View for AssistantPanel {
 283    fn ui_name() -> &'static str {
 284        "AssistantPanel"
 285    }
 286
 287    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
 288        let style = &theme::current(cx).assistant;
 289        if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 290            Flex::column()
 291                .with_child(
 292                    Text::new(
 293                        "Paste your OpenAI API key and press Enter to use the assistant",
 294                        style.api_key_prompt.text.clone(),
 295                    )
 296                    .aligned(),
 297                )
 298                .with_child(
 299                    ChildView::new(api_key_editor, cx)
 300                        .contained()
 301                        .with_style(style.api_key_editor.container)
 302                        .aligned(),
 303                )
 304                .contained()
 305                .with_style(style.api_key_prompt.container)
 306                .aligned()
 307                .into_any()
 308        } else {
 309            ChildView::new(&self.pane, cx).into_any()
 310        }
 311    }
 312
 313    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
 314        if cx.is_self_focused() {
 315            if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 316                cx.focus(api_key_editor);
 317            } else {
 318                cx.focus(&self.pane);
 319            }
 320        }
 321    }
 322}
 323
 324impl Panel for AssistantPanel {
 325    fn position(&self, cx: &WindowContext) -> DockPosition {
 326        match settings::get::<AssistantSettings>(cx).dock {
 327            AssistantDockPosition::Left => DockPosition::Left,
 328            AssistantDockPosition::Bottom => DockPosition::Bottom,
 329            AssistantDockPosition::Right => DockPosition::Right,
 330        }
 331    }
 332
 333    fn position_is_valid(&self, _: DockPosition) -> bool {
 334        true
 335    }
 336
 337    fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
 338        settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
 339            let dock = match position {
 340                DockPosition::Left => AssistantDockPosition::Left,
 341                DockPosition::Bottom => AssistantDockPosition::Bottom,
 342                DockPosition::Right => AssistantDockPosition::Right,
 343            };
 344            settings.dock = Some(dock);
 345        });
 346    }
 347
 348    fn size(&self, cx: &WindowContext) -> f32 {
 349        let settings = settings::get::<AssistantSettings>(cx);
 350        match self.position(cx) {
 351            DockPosition::Left | DockPosition::Right => {
 352                self.width.unwrap_or_else(|| settings.default_width)
 353            }
 354            DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
 355        }
 356    }
 357
 358    fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
 359        match self.position(cx) {
 360            DockPosition::Left | DockPosition::Right => self.width = Some(size),
 361            DockPosition::Bottom => self.height = Some(size),
 362        }
 363        cx.notify();
 364    }
 365
 366    fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
 367        matches!(event, AssistantPanelEvent::ZoomIn)
 368    }
 369
 370    fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
 371        matches!(event, AssistantPanelEvent::ZoomOut)
 372    }
 373
 374    fn is_zoomed(&self, cx: &WindowContext) -> bool {
 375        self.pane.read(cx).is_zoomed()
 376    }
 377
 378    fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
 379        self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
 380    }
 381
 382    fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
 383        if active {
 384            if self.api_key.borrow().is_none() && !self.has_read_credentials {
 385                self.has_read_credentials = true;
 386                let api_key = if let Some((_, api_key)) = cx
 387                    .platform()
 388                    .read_credentials(OPENAI_API_URL)
 389                    .log_err()
 390                    .flatten()
 391                {
 392                    String::from_utf8(api_key).log_err()
 393                } else {
 394                    None
 395                };
 396                if let Some(api_key) = api_key {
 397                    *self.api_key.borrow_mut() = Some(api_key);
 398                } else if self.api_key_editor.is_none() {
 399                    self.api_key_editor = Some(build_api_key_editor(cx));
 400                    cx.notify();
 401                }
 402            }
 403
 404            if self.pane.read(cx).items_len() == 0 {
 405                self.add_context(cx);
 406            }
 407        }
 408    }
 409
 410    fn icon_path(&self) -> &'static str {
 411        "icons/robot_14.svg"
 412    }
 413
 414    fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
 415        ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
 416    }
 417
 418    fn should_change_position_on_event(event: &Self::Event) -> bool {
 419        matches!(event, AssistantPanelEvent::DockPositionChanged)
 420    }
 421
 422    fn should_activate_on_event(_: &Self::Event) -> bool {
 423        false
 424    }
 425
 426    fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
 427        matches!(event, AssistantPanelEvent::Close)
 428    }
 429
 430    fn has_focus(&self, cx: &WindowContext) -> bool {
 431        self.pane.read(cx).has_focus()
 432            || self
 433                .api_key_editor
 434                .as_ref()
 435                .map_or(false, |editor| editor.is_focused(cx))
 436    }
 437
 438    fn is_focus_event(event: &Self::Event) -> bool {
 439        matches!(event, AssistantPanelEvent::Focus)
 440    }
 441}
 442
/// Events emitted by the `Assistant` model.
enum AssistantEvent {
    // Emitted when the buffer is edited, a message is inserted, or a
    // message's role is cycled.
    MessagesEdited,
    // NOTE(review): not emitted anywhere in the visible part of this file;
    // presumably raised when `summary` is updated — confirm.
    SummaryChanged,
    // Emitted after a streamed completion chunk has been handled.
    StreamedCompletion,
}
 448
/// Model backing one assistant conversation: a buffer holding the text of all
/// messages plus the per-message bookkeeping around it.
struct Assistant {
    // Buffer containing the text of every message.
    buffer: ModelHandle<Buffer>,
    // Start anchor of each message, in order of appearance.
    message_anchors: Vec<MessageAnchor>,
    // Role, timestamp and status for each message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Id handed to the next message that is created.
    next_message_id: MessageId,
    summary: Option<String>,
    pending_summary: Task<Option<()>>,
    // Counter used to assign ids to `pending_completions`.
    completion_count: usize,
    // In-flight completions; popping an entry drops its tasks.
    pending_completions: Vec<PendingCompletion>,
    // Model name sent with each request and used for token counting.
    model: String,
    // Most recently computed token count; `None` until first computed.
    token_count: Option<usize>,
    // Context size reported by `tiktoken_rs` for `model`.
    max_token_count: usize,
    // Latest token-counting task; replaced on every recount.
    pending_token_count: Task<Option<()>>,
    // API key shared with the panel.
    api_key: Rc<RefCell<Option<String>>>,
    _subscriptions: Vec<Subscription>,
}
 465
// The model emits `AssistantEvent`s.
impl Entity for Assistant {
    type Event = AssistantEvent;
}
 469
 470impl Assistant {
 471    fn new(
 472        api_key: Rc<RefCell<Option<String>>>,
 473        language_registry: Arc<LanguageRegistry>,
 474        cx: &mut ModelContext<Self>,
 475    ) -> Self {
 476        let model = "gpt-3.5-turbo-0613";
 477        let markdown = language_registry.language_for_name("Markdown");
 478        let buffer = cx.add_model(|cx| {
 479            let mut buffer = Buffer::new(0, "", cx);
 480            buffer.set_language_registry(language_registry);
 481            cx.spawn_weak(|buffer, mut cx| async move {
 482                let markdown = markdown.await?;
 483                let buffer = buffer
 484                    .upgrade(&cx)
 485                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
 486                buffer.update(&mut cx, |buffer, cx| {
 487                    buffer.set_language(Some(markdown), cx)
 488                });
 489                anyhow::Ok(())
 490            })
 491            .detach_and_log_err(cx);
 492            buffer
 493        });
 494
 495        let mut this = Self {
 496            message_anchors: Default::default(),
 497            messages_metadata: Default::default(),
 498            next_message_id: Default::default(),
 499            summary: None,
 500            pending_summary: Task::ready(None),
 501            completion_count: Default::default(),
 502            pending_completions: Default::default(),
 503            token_count: None,
 504            max_token_count: tiktoken_rs::model::get_context_size(model),
 505            pending_token_count: Task::ready(None),
 506            model: model.into(),
 507            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
 508            api_key,
 509            buffer,
 510        };
 511        let message = MessageAnchor {
 512            id: MessageId(post_inc(&mut this.next_message_id.0)),
 513            start: language::Anchor::MIN,
 514        };
 515        this.message_anchors.push(message.clone());
 516        this.messages_metadata.insert(
 517            message.id,
 518            MessageMetadata {
 519                role: Role::User,
 520                sent_at: Local::now(),
 521                status: MessageStatus::Done,
 522            },
 523        );
 524
 525        this.count_remaining_tokens(cx);
 526        this
 527    }
 528
 529    fn handle_buffer_event(
 530        &mut self,
 531        _: ModelHandle<Buffer>,
 532        event: &language::Event,
 533        cx: &mut ModelContext<Self>,
 534    ) {
 535        match event {
 536            language::Event::Edited => {
 537                self.count_remaining_tokens(cx);
 538                cx.emit(AssistantEvent::MessagesEdited);
 539            }
 540            _ => {}
 541        }
 542    }
 543
 544    fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
 545        let messages = self
 546            .messages(cx)
 547            .into_iter()
 548            .filter_map(|message| {
 549                Some(tiktoken_rs::ChatCompletionRequestMessage {
 550                    role: match message.role {
 551                        Role::User => "user".into(),
 552                        Role::Assistant => "assistant".into(),
 553                        Role::System => "system".into(),
 554                    },
 555                    content: self.buffer.read(cx).text_for_range(message.range).collect(),
 556                    name: None,
 557                })
 558            })
 559            .collect::<Vec<_>>();
 560        let model = self.model.clone();
 561        self.pending_token_count = cx.spawn_weak(|this, mut cx| {
 562            async move {
 563                cx.background().timer(Duration::from_millis(200)).await;
 564                let token_count = cx
 565                    .background()
 566                    .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
 567                    .await?;
 568
 569                this.upgrade(&cx)
 570                    .ok_or_else(|| anyhow!("assistant was dropped"))?
 571                    .update(&mut cx, |this, cx| {
 572                        this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
 573                        this.token_count = Some(token_count);
 574                        cx.notify()
 575                    });
 576                anyhow::Ok(())
 577            }
 578            .log_err()
 579        });
 580    }
 581
 582    fn remaining_tokens(&self) -> Option<isize> {
 583        Some(self.max_token_count as isize - self.token_count? as isize)
 584    }
 585
    /// Switches to `model` and schedules a token recount for it.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        // The recount reads `self.model`, so it must follow the assignment.
        self.count_remaining_tokens(cx);
        cx.notify();
    }
 591
 592    fn assist(
 593        &mut self,
 594        selected_messages: HashSet<MessageId>,
 595        cx: &mut ModelContext<Self>,
 596    ) -> Vec<MessageAnchor> {
 597        let mut user_messages = Vec::new();
 598        let mut tasks = Vec::new();
 599        for selected_message_id in selected_messages {
 600            let selected_message_role =
 601                if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
 602                    metadata.role
 603                } else {
 604                    continue;
 605                };
 606
 607            if selected_message_role == Role::Assistant {
 608                if let Some(user_message) = self.insert_message_after(
 609                    selected_message_id,
 610                    Role::User,
 611                    MessageStatus::Done,
 612                    cx,
 613                ) {
 614                    user_messages.push(user_message);
 615                } else {
 616                    continue;
 617                }
 618            } else {
 619                let request = OpenAIRequest {
 620                    model: self.model.clone(),
 621                    messages: self
 622                        .messages(cx)
 623                        .filter(|message| matches!(message.status, MessageStatus::Done))
 624                        .flat_map(|message| {
 625                            let mut system_message = None;
 626                            if message.id == selected_message_id {
 627                                system_message = Some(RequestMessage {
 628                                    role: Role::System,
 629                                    content: concat!(
 630                                        "Treat the following messages as additional knowledge you have learned about, ",
 631                                        "but act as if they were not part of this conversation. That is, treat them ",
 632                                        "as if the user didn't see them and couldn't possibly inquire about them."
 633                                    ).into()
 634                                });
 635                            }
 636
 637                            Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
 638                        })
 639                        .chain(Some(RequestMessage {
 640                            role: Role::System,
 641                            content: format!(
 642                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
 643                                selected_message_id.0
 644                            ),
 645                        }))
 646                        .collect(),
 647                    stream: true,
 648                };
 649
 650                let Some(api_key) = self.api_key.borrow().clone() else { continue };
 651                let stream = stream_completion(api_key, cx.background().clone(), request);
 652                let assistant_message = self
 653                    .insert_message_after(
 654                        selected_message_id,
 655                        Role::Assistant,
 656                        MessageStatus::Pending,
 657                        cx,
 658                    )
 659                    .unwrap();
 660
 661                tasks.push(cx.spawn_weak({
 662                    |this, mut cx| async move {
 663                        let assistant_message_id = assistant_message.id;
 664                        let stream_completion = async {
 665                            let mut messages = stream.await?;
 666
 667                            while let Some(message) = messages.next().await {
 668                                let mut message = message?;
 669                                if let Some(choice) = message.choices.pop() {
 670                                    this.upgrade(&cx)
 671                                        .ok_or_else(|| anyhow!("assistant was dropped"))?
 672                                        .update(&mut cx, |this, cx| {
 673                                            let text: Arc<str> = choice.delta.content?.into();
 674                                            let message_ix = this.message_anchors.iter().position(
 675                                                |message| message.id == assistant_message_id,
 676                                            )?;
 677                                            this.buffer.update(cx, |buffer, cx| {
 678                                                let offset = this.message_anchors[message_ix + 1..]
 679                                                    .iter()
 680                                                    .find(|message| message.start.is_valid(buffer))
 681                                                    .map_or(buffer.len(), |message| {
 682                                                        message
 683                                                            .start
 684                                                            .to_offset(buffer)
 685                                                            .saturating_sub(1)
 686                                                    });
 687                                                buffer.edit([(offset..offset, text)], None, cx);
 688                                            });
 689                                            cx.emit(AssistantEvent::StreamedCompletion);
 690
 691                                            Some(())
 692                                        });
 693                                }
 694                                smol::future::yield_now().await;
 695                            }
 696
 697                            this.upgrade(&cx)
 698                                .ok_or_else(|| anyhow!("assistant was dropped"))?
 699                                .update(&mut cx, |this, cx| {
 700                                    this.pending_completions.retain(|completion| {
 701                                        completion.id != this.completion_count
 702                                    });
 703                                    this.summarize(cx);
 704                                });
 705
 706                            anyhow::Ok(())
 707                        };
 708
 709                        let result = stream_completion.await;
 710                        if let Some(this) = this.upgrade(&cx) {
 711                            this.update(&mut cx, |this, cx| {
 712                                if let Some(metadata) =
 713                                    this.messages_metadata.get_mut(&assistant_message.id)
 714                                {
 715                                    match result {
 716                                        Ok(_) => {
 717                                            metadata.status = MessageStatus::Done;
 718                                        }
 719                                        Err(error) => {
 720                                            metadata.status = MessageStatus::Error(
 721                                                error.to_string().trim().into(),
 722                                            );
 723                                        }
 724                                    }
 725                                    cx.notify();
 726                                }
 727                            });
 728                        }
 729                    }
 730                }));
 731            }
 732        }
 733
 734        if !tasks.is_empty() {
 735            self.pending_completions.push(PendingCompletion {
 736                id: post_inc(&mut self.completion_count),
 737                _tasks: tasks,
 738            });
 739        }
 740
 741        user_messages
 742    }
 743
    /// Pops the most recently added pending completion, dropping its tasks.
    /// Returns whether there was one to pop.
    fn cancel_last_assist(&mut self) -> bool {
        self.pending_completions.pop().is_some()
    }
 747
 748    fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
 749        for id in ids {
 750            if let Some(metadata) = self.messages_metadata.get_mut(&id) {
 751                metadata.role.cycle();
 752                cx.emit(AssistantEvent::MessagesEdited);
 753                cx.notify();
 754            }
 755        }
 756    }
 757
 758    fn insert_message_after(
 759        &mut self,
 760        message_id: MessageId,
 761        role: Role,
 762        status: MessageStatus,
 763        cx: &mut ModelContext<Self>,
 764    ) -> Option<MessageAnchor> {
 765        if let Some(prev_message_ix) = self
 766            .message_anchors
 767            .iter()
 768            .position(|message| message.id == message_id)
 769        {
 770            let start = self.buffer.update(cx, |buffer, cx| {
 771                let offset = self.message_anchors[prev_message_ix + 1..]
 772                    .iter()
 773                    .find(|message| message.start.is_valid(buffer))
 774                    .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
 775                buffer.edit([(offset..offset, "\n")], None, cx);
 776                buffer.anchor_before(offset + 1)
 777            });
 778            let message = MessageAnchor {
 779                id: MessageId(post_inc(&mut self.next_message_id.0)),
 780                start,
 781            };
 782            self.message_anchors
 783                .insert(prev_message_ix + 1, message.clone());
 784            self.messages_metadata.insert(
 785                message.id,
 786                MessageMetadata {
 787                    role,
 788                    sent_at: Local::now(),
 789                    status,
 790                },
 791            );
 792            cx.emit(AssistantEvent::MessagesEdited);
 793            Some(message)
 794        } else {
 795            None
 796        }
 797    }
 798
 799    fn split_message(
 800        &mut self,
 801        range: Range<usize>,
 802        cx: &mut ModelContext<Self>,
 803    ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
 804        let start_message = self.message_for_offset(range.start, cx);
 805        let end_message = self.message_for_offset(range.end, cx);
 806        if let Some((start_message, end_message)) = start_message.zip(end_message) {
 807            // Prevent splitting when range spans multiple messages.
 808            if start_message.index != end_message.index {
 809                return (None, None);
 810            }
 811
 812            let message = start_message;
 813            let role = message.role;
 814            let mut edited_buffer = false;
 815
 816            let mut suffix_start = None;
 817            if range.start > message.range.start && range.end < message.range.end - 1 {
 818                if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
 819                    suffix_start = Some(range.end + 1);
 820                } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
 821                    suffix_start = Some(range.end);
 822                }
 823            }
 824
 825            let suffix = if let Some(suffix_start) = suffix_start {
 826                MessageAnchor {
 827                    id: MessageId(post_inc(&mut self.next_message_id.0)),
 828                    start: self.buffer.read(cx).anchor_before(suffix_start),
 829                }
 830            } else {
 831                self.buffer.update(cx, |buffer, cx| {
 832                    buffer.edit([(range.end..range.end, "\n")], None, cx);
 833                });
 834                edited_buffer = true;
 835                MessageAnchor {
 836                    id: MessageId(post_inc(&mut self.next_message_id.0)),
 837                    start: self.buffer.read(cx).anchor_before(range.end + 1),
 838                }
 839            };
 840
 841            self.message_anchors
 842                .insert(message.index + 1, suffix.clone());
 843            self.messages_metadata.insert(
 844                suffix.id,
 845                MessageMetadata {
 846                    role,
 847                    sent_at: Local::now(),
 848                    status: MessageStatus::Done,
 849                },
 850            );
 851
 852            let new_messages = if range.start == range.end || range.start == message.range.start {
 853                (None, Some(suffix))
 854            } else {
 855                let mut prefix_end = None;
 856                if range.start > message.range.start && range.end < message.range.end - 1 {
 857                    if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
 858                        prefix_end = Some(range.start + 1);
 859                    } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
 860                        == Some('\n')
 861                    {
 862                        prefix_end = Some(range.start);
 863                    }
 864                }
 865
 866                let selection = if let Some(prefix_end) = prefix_end {
 867                    cx.emit(AssistantEvent::MessagesEdited);
 868                    MessageAnchor {
 869                        id: MessageId(post_inc(&mut self.next_message_id.0)),
 870                        start: self.buffer.read(cx).anchor_before(prefix_end),
 871                    }
 872                } else {
 873                    self.buffer.update(cx, |buffer, cx| {
 874                        buffer.edit([(range.start..range.start, "\n")], None, cx)
 875                    });
 876                    edited_buffer = true;
 877                    MessageAnchor {
 878                        id: MessageId(post_inc(&mut self.next_message_id.0)),
 879                        start: self.buffer.read(cx).anchor_before(range.end + 1),
 880                    }
 881                };
 882
 883                self.message_anchors
 884                    .insert(message.index + 1, selection.clone());
 885                self.messages_metadata.insert(
 886                    selection.id,
 887                    MessageMetadata {
 888                        role,
 889                        sent_at: Local::now(),
 890                        status: MessageStatus::Done,
 891                    },
 892                );
 893                (Some(selection), Some(suffix))
 894            };
 895
 896            if !edited_buffer {
 897                cx.emit(AssistantEvent::MessagesEdited);
 898            }
 899            new_messages
 900        } else {
 901            (None, None)
 902        }
 903    }
 904
 905    fn summarize(&mut self, cx: &mut ModelContext<Self>) {
 906        if self.message_anchors.len() >= 2 && self.summary.is_none() {
 907            let api_key = self.api_key.borrow().clone();
 908            if let Some(api_key) = api_key {
 909                let messages = self
 910                    .messages(cx)
 911                    .take(2)
 912                    .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
 913                    .chain(Some(RequestMessage {
 914                        role: Role::User,
 915                        content:
 916                            "Summarize the conversation into a short title without punctuation"
 917                                .into(),
 918                    }));
 919                let request = OpenAIRequest {
 920                    model: self.model.clone(),
 921                    messages: messages.collect(),
 922                    stream: true,
 923                };
 924
 925                let stream = stream_completion(api_key, cx.background().clone(), request);
 926                self.pending_summary = cx.spawn(|this, mut cx| {
 927                    async move {
 928                        let mut messages = stream.await?;
 929
 930                        while let Some(message) = messages.next().await {
 931                            let mut message = message?;
 932                            if let Some(choice) = message.choices.pop() {
 933                                let text = choice.delta.content.unwrap_or_default();
 934                                this.update(&mut cx, |this, cx| {
 935                                    this.summary.get_or_insert(String::new()).push_str(&text);
 936                                    cx.emit(AssistantEvent::SummaryChanged);
 937                                });
 938                            }
 939                        }
 940
 941                        anyhow::Ok(())
 942                    }
 943                    .log_err()
 944                });
 945            }
 946        }
 947    }
 948
 949    fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
 950        self.messages_for_offsets([offset], cx).pop()
 951    }
 952
 953    fn messages_for_offsets(
 954        &self,
 955        offsets: impl IntoIterator<Item = usize>,
 956        cx: &AppContext,
 957    ) -> Vec<Message> {
 958        let mut result = Vec::new();
 959
 960        let buffer_len = self.buffer.read(cx).len();
 961        let mut messages = self.messages(cx).peekable();
 962        let mut offsets = offsets.into_iter().peekable();
 963        while let Some(offset) = offsets.next() {
 964            // Skip messages that start after the offset.
 965            while messages.peek().map_or(false, |message| {
 966                message.range.end < offset || (message.range.end == offset && offset < buffer_len)
 967            }) {
 968                messages.next();
 969            }
 970            let Some(message) = messages.peek() else { continue };
 971
 972            // Skip offsets that are in the same message.
 973            while offsets.peek().map_or(false, |offset| {
 974                message.range.contains(offset) || message.range.end == buffer_len
 975            }) {
 976                offsets.next();
 977            }
 978
 979            result.push(message.clone());
 980        }
 981        result
 982    }
 983
 984    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
 985        let buffer = self.buffer.read(cx);
 986        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
 987        iter::from_fn(move || {
 988            while let Some((ix, message_anchor)) = message_anchors.next() {
 989                let metadata = self.messages_metadata.get(&message_anchor.id)?;
 990                let message_start = message_anchor.start.to_offset(buffer);
 991                let mut message_end = None;
 992                while let Some((_, next_message)) = message_anchors.peek() {
 993                    if next_message.start.is_valid(buffer) {
 994                        message_end = Some(next_message.start);
 995                        break;
 996                    } else {
 997                        message_anchors.next();
 998                    }
 999                }
1000                let message_end = message_end
1001                    .unwrap_or(language::Anchor::MAX)
1002                    .to_offset(buffer);
1003                return Some(Message {
1004                    index: ix,
1005                    range: message_start..message_end,
1006                    id: message_anchor.id,
1007                    anchor: message_anchor.start,
1008                    role: metadata.role,
1009                    sent_at: metadata.sent_at,
1010                    status: metadata.status.clone(),
1011                });
1012            }
1013            None
1014        })
1015    }
1016}
1017
1018struct PendingCompletion {
1019    id: usize,
1020    _tasks: Vec<Task<()>>,
1021}
1022
/// Events emitted by `AssistantEditor`.
enum AssistantEditorEvent {
    /// The content shown in the editor's tab may have changed.
    TabContentChanged,
}
1026
1027#[derive(Copy, Clone, Debug, PartialEq)]
1028struct ScrollPosition {
1029    offset_before_cursor: Vector2F,
1030    cursor: Anchor,
1031}
1032
1033struct AssistantEditor {
1034    assistant: ModelHandle<Assistant>,
1035    editor: ViewHandle<Editor>,
1036    blocks: HashSet<BlockId>,
1037    scroll_position: Option<ScrollPosition>,
1038    _subscriptions: Vec<Subscription>,
1039}
1040
1041impl AssistantEditor {
1042    fn new(
1043        api_key: Rc<RefCell<Option<String>>>,
1044        language_registry: Arc<LanguageRegistry>,
1045        cx: &mut ViewContext<Self>,
1046    ) -> Self {
1047        let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
1048        let editor = cx.add_view(|cx| {
1049            let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
1050            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1051            editor.set_show_gutter(false, cx);
1052            editor
1053        });
1054
1055        let _subscriptions = vec![
1056            cx.observe(&assistant, |_, _, cx| cx.notify()),
1057            cx.subscribe(&assistant, Self::handle_assistant_event),
1058            cx.subscribe(&editor, Self::handle_editor_event),
1059        ];
1060
1061        let mut this = Self {
1062            assistant,
1063            editor,
1064            blocks: Default::default(),
1065            scroll_position: None,
1066            _subscriptions,
1067        };
1068        this.update_message_headers(cx);
1069        this
1070    }
1071
1072    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1073        let cursors = self.cursors(cx);
1074
1075        let user_messages = self.assistant.update(cx, |assistant, cx| {
1076            let selected_messages = assistant
1077                .messages_for_offsets(cursors, cx)
1078                .into_iter()
1079                .map(|message| message.id)
1080                .collect();
1081            assistant.assist(selected_messages, cx)
1082        });
1083        let new_selections = user_messages
1084            .iter()
1085            .map(|message| {
1086                let cursor = message
1087                    .start
1088                    .to_offset(self.assistant.read(cx).buffer.read(cx));
1089                cursor..cursor
1090            })
1091            .collect::<Vec<_>>();
1092        if !new_selections.is_empty() {
1093            self.editor.update(cx, |editor, cx| {
1094                editor.change_selections(
1095                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1096                    cx,
1097                    |selections| selections.select_ranges(new_selections),
1098                );
1099            });
1100        }
1101    }
1102
1103    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1104        if !self
1105            .assistant
1106            .update(cx, |assistant, _| assistant.cancel_last_assist())
1107        {
1108            cx.propagate_action();
1109        }
1110    }
1111
1112    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1113        let cursors = self.cursors(cx);
1114        self.assistant.update(cx, |assistant, cx| {
1115            let messages = assistant
1116                .messages_for_offsets(cursors, cx)
1117                .into_iter()
1118                .map(|message| message.id)
1119                .collect();
1120            assistant.cycle_message_roles(messages, cx)
1121        });
1122    }
1123
1124    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1125        let selections = self.editor.read(cx).selections.all::<usize>(cx);
1126        selections
1127            .into_iter()
1128            .map(|selection| selection.head())
1129            .collect()
1130    }
1131
1132    fn handle_assistant_event(
1133        &mut self,
1134        _: ModelHandle<Assistant>,
1135        event: &AssistantEvent,
1136        cx: &mut ViewContext<Self>,
1137    ) {
1138        match event {
1139            AssistantEvent::MessagesEdited => self.update_message_headers(cx),
1140            AssistantEvent::SummaryChanged => {
1141                cx.emit(AssistantEditorEvent::TabContentChanged);
1142            }
1143            AssistantEvent::StreamedCompletion => {
1144                self.editor.update(cx, |editor, cx| {
1145                    if let Some(scroll_position) = self.scroll_position {
1146                        let snapshot = editor.snapshot(cx);
1147                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1148                        let scroll_top =
1149                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1150                        editor.set_scroll_position(
1151                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1152                            cx,
1153                        );
1154                    }
1155                });
1156            }
1157        }
1158    }
1159
1160    fn handle_editor_event(
1161        &mut self,
1162        _: ViewHandle<Editor>,
1163        event: &editor::Event,
1164        cx: &mut ViewContext<Self>,
1165    ) {
1166        match event {
1167            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1168                let cursor_scroll_position = self.cursor_scroll_position(cx);
1169                if *autoscroll {
1170                    self.scroll_position = cursor_scroll_position;
1171                } else if self.scroll_position != cursor_scroll_position {
1172                    self.scroll_position = None;
1173                }
1174            }
1175            editor::Event::SelectionsChanged { .. } => {
1176                self.scroll_position = self.cursor_scroll_position(cx);
1177            }
1178            _ => {}
1179        }
1180    }
1181
1182    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1183        self.editor.update(cx, |editor, cx| {
1184            let snapshot = editor.snapshot(cx);
1185            let cursor = editor.selections.newest_anchor().head();
1186            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1187            let scroll_position = editor
1188                .scroll_manager
1189                .anchor()
1190                .scroll_position(&snapshot.display_snapshot);
1191
1192            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1193            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1194                Some(ScrollPosition {
1195                    cursor,
1196                    offset_before_cursor: vec2f(
1197                        scroll_position.x(),
1198                        cursor_row - scroll_position.y(),
1199                    ),
1200                })
1201            } else {
1202                None
1203            }
1204        })
1205    }
1206
1207    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
1208        self.editor.update(cx, |editor, cx| {
1209            let buffer = editor.buffer().read(cx).snapshot(cx);
1210            let excerpt_id = *buffer.as_singleton().unwrap().0;
1211            let old_blocks = std::mem::take(&mut self.blocks);
1212            let new_blocks = self
1213                .assistant
1214                .read(cx)
1215                .messages(cx)
1216                .map(|message| BlockProperties {
1217                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
1218                    height: 2,
1219                    style: BlockStyle::Sticky,
1220                    render: Arc::new({
1221                        let assistant = self.assistant.clone();
1222                        // let metadata = message.metadata.clone();
1223                        // let message = message.clone();
1224                        move |cx| {
1225                            enum Sender {}
1226                            enum ErrorTooltip {}
1227
1228                            let theme = theme::current(cx);
1229                            let style = &theme.assistant;
1230                            let message_id = message.id;
1231                            let sender = MouseEventHandler::<Sender, _>::new(
1232                                message_id.0,
1233                                cx,
1234                                |state, _| match message.role {
1235                                    Role::User => {
1236                                        let style = style.user_sender.style_for(state, false);
1237                                        Label::new("You", style.text.clone())
1238                                            .contained()
1239                                            .with_style(style.container)
1240                                    }
1241                                    Role::Assistant => {
1242                                        let style = style.assistant_sender.style_for(state, false);
1243                                        Label::new("Assistant", style.text.clone())
1244                                            .contained()
1245                                            .with_style(style.container)
1246                                    }
1247                                    Role::System => {
1248                                        let style = style.system_sender.style_for(state, false);
1249                                        Label::new("System", style.text.clone())
1250                                            .contained()
1251                                            .with_style(style.container)
1252                                    }
1253                                },
1254                            )
1255                            .with_cursor_style(CursorStyle::PointingHand)
1256                            .on_down(MouseButton::Left, {
1257                                let assistant = assistant.clone();
1258                                move |_, _, cx| {
1259                                    assistant.update(cx, |assistant, cx| {
1260                                        assistant.cycle_message_roles(
1261                                            HashSet::from_iter(Some(message_id)),
1262                                            cx,
1263                                        )
1264                                    })
1265                                }
1266                            });
1267
1268                            Flex::row()
1269                                .with_child(sender.aligned())
1270                                .with_child(
1271                                    Label::new(
1272                                        message.sent_at.format("%I:%M%P").to_string(),
1273                                        style.sent_at.text.clone(),
1274                                    )
1275                                    .contained()
1276                                    .with_style(style.sent_at.container)
1277                                    .aligned(),
1278                                )
1279                                .with_children(
1280                                    if let MessageStatus::Error(error) = &message.status {
1281                                        Some(
1282                                            Svg::new("icons/circle_x_mark_12.svg")
1283                                                .with_color(style.error_icon.color)
1284                                                .constrained()
1285                                                .with_width(style.error_icon.width)
1286                                                .contained()
1287                                                .with_style(style.error_icon.container)
1288                                                .with_tooltip::<ErrorTooltip>(
1289                                                    message_id.0,
1290                                                    error.to_string(),
1291                                                    None,
1292                                                    theme.tooltip.clone(),
1293                                                    cx,
1294                                                )
1295                                                .aligned(),
1296                                        )
1297                                    } else {
1298                                        None
1299                                    },
1300                                )
1301                                .aligned()
1302                                .left()
1303                                .contained()
1304                                .with_style(style.header)
1305                                .into_any()
1306                        }
1307                    }),
1308                    disposition: BlockDisposition::Above,
1309                })
1310                .collect::<Vec<_>>();
1311
1312            editor.remove_blocks(old_blocks, None, cx);
1313            let ids = editor.insert_blocks(new_blocks, None, cx);
1314            self.blocks = HashSet::from_iter(ids);
1315        });
1316    }
1317
1318    fn quote_selection(
1319        workspace: &mut Workspace,
1320        _: &QuoteSelection,
1321        cx: &mut ViewContext<Workspace>,
1322    ) {
1323        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1324            return;
1325        };
1326        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1327            return;
1328        };
1329
1330        let text = editor.read_with(cx, |editor, cx| {
1331            let range = editor.selections.newest::<usize>(cx).range();
1332            let buffer = editor.buffer().read(cx).snapshot(cx);
1333            let start_language = buffer.language_at(range.start);
1334            let end_language = buffer.language_at(range.end);
1335            let language_name = if start_language == end_language {
1336                start_language.map(|language| language.name())
1337            } else {
1338                None
1339            };
1340            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1341
1342            let selected_text = buffer.text_for_range(range).collect::<String>();
1343            if selected_text.is_empty() {
1344                None
1345            } else {
1346                Some(if language_name == "markdown" {
1347                    selected_text
1348                        .lines()
1349                        .map(|line| format!("> {}", line))
1350                        .collect::<Vec<_>>()
1351                        .join("\n")
1352                } else {
1353                    format!("```{language_name}\n{selected_text}\n```")
1354                })
1355            }
1356        });
1357
1358        // Activate the panel
1359        if !panel.read(cx).has_focus(cx) {
1360            workspace.toggle_panel_focus::<AssistantPanel>(cx);
1361        }
1362
1363        if let Some(text) = text {
1364            panel.update(cx, |panel, cx| {
1365                if let Some(assistant) = panel
1366                    .pane
1367                    .read(cx)
1368                    .active_item()
1369                    .and_then(|item| item.downcast::<AssistantEditor>())
1370                    .ok_or_else(|| anyhow!("no active context"))
1371                    .log_err()
1372                {
1373                    assistant.update(cx, |assistant, cx| {
1374                        assistant
1375                            .editor
1376                            .update(cx, |editor, cx| editor.insert(&text, cx))
1377                    });
1378                }
1379            });
1380        }
1381    }
1382
1383    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1384        let editor = self.editor.read(cx);
1385        let assistant = self.assistant.read(cx);
1386        if editor.selections.count() == 1 {
1387            let selection = editor.selections.newest::<usize>(cx);
1388            let mut copied_text = String::new();
1389            let mut spanned_messages = 0;
1390            for message in assistant.messages(cx) {
1391                if message.range.start >= selection.range().end {
1392                    break;
1393                } else if message.range.end >= selection.range().start {
1394                    let range = cmp::max(message.range.start, selection.range().start)
1395                        ..cmp::min(message.range.end, selection.range().end);
1396                    if !range.is_empty() {
1397                        spanned_messages += 1;
1398                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1399                        for chunk in assistant.buffer.read(cx).text_for_range(range) {
1400                            copied_text.push_str(&chunk);
1401                        }
1402                        copied_text.push('\n');
1403                    }
1404                }
1405            }
1406
1407            if spanned_messages > 1 {
1408                cx.platform()
1409                    .write_to_clipboard(ClipboardItem::new(copied_text));
1410                return;
1411            }
1412        }
1413
1414        cx.propagate_action();
1415    }
1416
1417    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1418        self.assistant.update(cx, |assistant, cx| {
1419            let selections = self.editor.read(cx).selections.disjoint_anchors();
1420            for selection in selections.into_iter() {
1421                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
1422                let range = selection
1423                    .map(|endpoint| endpoint.to_offset(&buffer))
1424                    .range();
1425                assistant.split_message(range, cx);
1426            }
1427        });
1428    }
1429
1430    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1431        self.assistant.update(cx, |assistant, cx| {
1432            let new_model = match assistant.model.as_str() {
1433                "gpt-4-0613" => "gpt-3.5-turbo-0613",
1434                _ => "gpt-4-0613",
1435            };
1436            assistant.set_model(new_model.into(), cx);
1437        });
1438    }
1439
1440    fn title(&self, cx: &AppContext) -> String {
1441        self.assistant
1442            .read(cx)
1443            .summary
1444            .clone()
1445            .unwrap_or_else(|| "New Context".into())
1446    }
1447}
1448
1449impl Entity for AssistantEditor {
1450    type Event = AssistantEditorEvent;
1451}
1452
1453impl View for AssistantEditor {
1454    fn ui_name() -> &'static str {
1455        "AssistantEditor"
1456    }
1457
1458    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
1459        enum Model {}
1460        let theme = &theme::current(cx).assistant;
1461        let assistant = &self.assistant.read(cx);
1462        let model = assistant.model.clone();
1463        let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
1464            let remaining_tokens_style = if remaining_tokens <= 0 {
1465                &theme.no_remaining_tokens
1466            } else {
1467                &theme.remaining_tokens
1468            };
1469            Label::new(
1470                remaining_tokens.to_string(),
1471                remaining_tokens_style.text.clone(),
1472            )
1473            .contained()
1474            .with_style(remaining_tokens_style.container)
1475        });
1476
1477        Stack::new()
1478            .with_child(
1479                ChildView::new(&self.editor, cx)
1480                    .contained()
1481                    .with_style(theme.container),
1482            )
1483            .with_child(
1484                Flex::row()
1485                    .with_child(
1486                        MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
1487                            let style = theme.model.style_for(state, false);
1488                            Label::new(model, style.text.clone())
1489                                .contained()
1490                                .with_style(style.container)
1491                        })
1492                        .with_cursor_style(CursorStyle::PointingHand)
1493                        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
1494                    )
1495                    .with_children(remaining_tokens)
1496                    .contained()
1497                    .with_style(theme.model_info_container)
1498                    .aligned()
1499                    .top()
1500                    .right(),
1501            )
1502            .into_any()
1503    }
1504
1505    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
1506        if cx.is_self_focused() {
1507            cx.focus(&self.editor);
1508        }
1509    }
1510}
1511
1512impl Item for AssistantEditor {
1513    fn tab_content<V: View>(
1514        &self,
1515        _: Option<usize>,
1516        style: &theme::Tab,
1517        cx: &gpui::AppContext,
1518    ) -> AnyElement<V> {
1519        let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1520        Label::new(title, style.label.clone()).into_any()
1521    }
1522
1523    fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1524        Some(self.title(cx).into())
1525    }
1526
1527    fn as_searchable(
1528        &self,
1529        _: &ViewHandle<Self>,
1530    ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1531        Some(Box::new(self.editor.clone()))
1532    }
1533}
1534
/// Identifier of a message: a newtype over `usize` that is cheap to copy,
/// compare and hash.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1537
/// Associates a message id with a buffer anchor.
#[derive(Clone, Debug)]
struct MessageAnchor {
    /// Id of the message this anchor belongs to.
    id: MessageId,
    // NOTE(review): presumably the buffer position where the message begins —
    // confirm against the code that creates these anchors.
    start: language::Anchor,
}
1543
/// Metadata stored for a message: its role, a local timestamp and its status.
#[derive(Clone, Debug)]
struct MessageMetadata {
    /// Role attributed to the message.
    role: Role,
    // NOTE(review): presumably the local time at which the message was
    // created or sent — confirm where this is populated.
    sent_at: DateTime<Local>,
    /// Current status of the message (see `MessageStatus`).
    status: MessageStatus,
}
1550
/// Status of a message.
#[derive(Clone, Debug)]
enum MessageStatus {
    // NOTE(review): presumably "response still in flight" — confirm at the
    // sites that set this variant.
    Pending,
    Done,
    // Carries a shared string; presumably the error description to display —
    // TODO confirm.
    Error(Arc<str>),
}
1557
/// A message together with its location in the buffer and its metadata.
#[derive(Clone, Debug)]
pub struct Message {
    /// Offset range of the message's text; `to_open_ai_message` passes it to
    /// `Buffer::text_for_range` to extract the content.
    range: Range<usize>,
    // NOTE(review): presumably the message's position within the
    // conversation — confirm against the code that builds `Message` values.
    index: usize,
    /// Id of the message.
    id: MessageId,
    // NOTE(review): presumably the anchor marking the start of the message in
    // the buffer — confirm.
    anchor: language::Anchor,
    /// Role attributed to the message; forwarded as the request message role.
    role: Role,
    // NOTE(review): presumably the local time the message was created or
    // sent — confirm.
    sent_at: DateTime<Local>,
    /// Current status of the message (see `MessageStatus`).
    status: MessageStatus,
}
1568
1569impl Message {
1570    fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
1571        let mut content = format!("[Message {}]\n", self.id.0).to_string();
1572        content.extend(buffer.text_for_range(self.range.clone()));
1573        RequestMessage {
1574            role: self.role,
1575            content,
1576        }
1577    }
1578}
1579
1580async fn stream_completion(
1581    api_key: String,
1582    executor: Arc<Background>,
1583    mut request: OpenAIRequest,
1584) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1585    request.stream = true;
1586
1587    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1588
1589    let json_data = serde_json::to_string(&request)?;
1590    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1591        .header("Content-Type", "application/json")
1592        .header("Authorization", format!("Bearer {}", api_key))
1593        .body(json_data)?
1594        .send_async()
1595        .await?;
1596
1597    let status = response.status();
1598    if status == StatusCode::OK {
1599        executor
1600            .spawn(async move {
1601                let mut lines = BufReader::new(response.body_mut()).lines();
1602
1603                fn parse_line(
1604                    line: Result<String, io::Error>,
1605                ) -> Result<Option<OpenAIResponseStreamEvent>> {
1606                    if let Some(data) = line?.strip_prefix("data: ") {
1607                        let event = serde_json::from_str(&data)?;
1608                        Ok(Some(event))
1609                    } else {
1610                        Ok(None)
1611                    }
1612                }
1613
1614                while let Some(line) = lines.next().await {
1615                    if let Some(event) = parse_line(line).transpose() {
1616                        let done = event.as_ref().map_or(false, |event| {
1617                            event
1618                                .choices
1619                                .last()
1620                                .map_or(false, |choice| choice.finish_reason.is_some())
1621                        });
1622                        if tx.unbounded_send(event).is_err() {
1623                            break;
1624                        }
1625
1626                        if done {
1627                            break;
1628                        }
1629                    }
1630                }
1631
1632                anyhow::Ok(())
1633            })
1634            .detach();
1635
1636        Ok(rx)
1637    } else {
1638        let mut body = String::new();
1639        response.body_mut().read_to_string(&mut body).await?;
1640
1641        #[derive(Deserialize)]
1642        struct OpenAIResponse {
1643            error: OpenAIError,
1644        }
1645
1646        #[derive(Deserialize)]
1647        struct OpenAIError {
1648            message: String,
1649        }
1650
1651        match serde_json::from_str::<OpenAIResponse>(&body) {
1652            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1653                "Failed to connect to OpenAI API: {}",
1654                response.error.message,
1655            )),
1656
1657            _ => Err(anyhow!(
1658                "Failed to connect to OpenAI API: {} {}",
1659                response.status(),
1660                body,
1661            )),
1662        }
1663    }
1664}
1665
// Unit tests for how the assistant tracks message boundaries while the
// underlying buffer is edited, split, and queried by offset.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::AppContext;

    /// Checks the message list after inserting messages, editing the buffer,
    /// deleting across message boundaries, and undoing/redoing that deletion.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        // A new assistant is expected to hold a single, empty user message.
        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message 1 is expected to grow message 1 to 0..1 and
        // to place the new, empty message at 1..1.
        let message_2 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text typed into each message is expected to extend only that
        // message's range.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        // Appending a message after the current last one.
        let message_3 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message 2 again is expected to place the new
        // message between message 2 and message 3.
        let message_4 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        // Text typed into the two trailing messages.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// Checks the buffer text and message ranges produced by `split_message`
    /// for empty ranges next to newlines and for a non-empty range.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // Split at offset 3, directly in front of the first newline.
        let (_, message_2) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        // Split at the same offset again; message 1 now ends right after it.
        let (_, message_3) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        // Split at offset 9, directly after the newline that follows "bbb":
        // the text is expected to stay unchanged.
        let (_, message_4) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        // Split at offset 9 once more, now the very start of message 4: a
        // newline is expected to be inserted.
        let (_, message_5) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Split around a non-empty range (the text "dd"): both returned
        // messages are expected to be present, one covering the range and one
        // covering what follows it.
        let (message_6, message_7) =
            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// Checks which messages `messages_for_offsets` reports for sets of
    /// buffer offsets.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build three messages holding "aaa", "bbb" and "ccc".
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = assistant
            .update(cx, |assistant, cx| {
                assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = assistant
            .update(cx, |assistant, cx| {
                assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        // One offset inside each message.
        assert_eq!(
            message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Two offsets inside message 1 are expected to report it only once;
        // offset 11 (the end of the buffer) is expected to map to message 3.
        assert_eq!(
            message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // Collects the ids of the messages reported for `offsets`.
        fn message_ids_for_offsets(
            assistant: &ModelHandle<Assistant>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            assistant
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Summarizes the assistant's messages as `(id, role, range)` tuples so
    /// the tests can compare them against expected values.
    fn messages(
        assistant: &ModelHandle<Assistant>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        assistant
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.range))
            .collect()
    }
}