assistant.rs

   1use crate::{
   2    assistant_settings::{AssistantDockPosition, AssistantSettings},
   3    OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role,
   4};
   5use anyhow::{anyhow, Result};
   6use chrono::{DateTime, Local};
   7use collections::{HashMap, HashSet};
   8use editor::{
   9    display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
  10    scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
  11    Anchor, Editor,
  12};
  13use fs::Fs;
  14use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  15use gpui::{
  16    actions,
  17    elements::*,
  18    executor::Background,
  19    geometry::vector::{vec2f, Vector2F},
  20    platform::{CursorStyle, MouseButton},
  21    Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
  22    Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
  23};
  24use isahc::{http::StatusCode, Request, RequestExt};
  25use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
  26use serde::Deserialize;
  27use settings::SettingsStore;
  28use std::{
  29    borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
  30    time::Duration,
  31};
  32use util::{channel::ReleaseChannel, post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
  33use workspace::{
  34    dock::{DockPosition, Panel},
  35    item::Item,
  36    pane, Pane, Workspace,
  37};
  38
  39const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
  40
// Declares the actions of the `assistant` namespace. On the stable release
// channel `init` hides this whole namespace from the command palette.
actions!(
    assistant,
    [
        NewContext,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey
    ]
);
  53
  54pub fn init(cx: &mut AppContext) {
  55    if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
  56        cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
  57            filter.filtered_namespaces.insert("assistant");
  58        });
  59    }
  60
  61    settings::register::<AssistantSettings>(cx);
  62    cx.add_action(
  63        |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
  64            if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
  65                this.update(cx, |this, cx| this.add_context(cx))
  66            }
  67
  68            workspace.focus_panel::<AssistantPanel>(cx);
  69        },
  70    );
  71    cx.add_action(AssistantEditor::assist);
  72    cx.capture_action(AssistantEditor::cancel_last_assist);
  73    cx.add_action(AssistantEditor::quote_selection);
  74    cx.capture_action(AssistantEditor::copy);
  75    cx.capture_action(AssistantEditor::split);
  76    cx.capture_action(AssistantEditor::cycle_message_role);
  77    cx.add_action(AssistantPanel::save_api_key);
  78    cx.add_action(AssistantPanel::reset_api_key);
  79    cx.add_action(
  80        |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
  81            workspace.toggle_panel_focus::<AssistantPanel>(cx);
  82        },
  83    );
  84}
  85
/// Events emitted by [`AssistantPanel`]. The panel's `Panel` impl decides, via
/// its `should_*_on_event` / `is_focus_event` hooks, how each one is handled.
pub enum AssistantPanelEvent {
    /// Re-emitted from the inner pane's `ZoomIn` event.
    ZoomIn,
    /// Re-emitted from the inner pane's `ZoomOut` event.
    ZoomOut,
    /// Re-emitted from the inner pane's `Focus` event.
    Focus,
    /// Emitted when the inner pane reports `Remove`; closes the panel.
    Close,
    /// Emitted when a settings change moves the panel's dock position.
    DockPositionChanged,
}
  93
/// Dockable panel that hosts assistant conversations ("contexts") as items of
/// an inner pane. While `api_key_editor` is set, the panel renders an API-key
/// prompt instead of the pane.
pub struct AssistantPanel {
    /// Size chosen via `set_size` while docked left/right; `None` means use
    /// the `default_width` setting.
    width: Option<f32>,
    /// Size chosen via `set_size` while docked at the bottom; `None` means use
    /// the `default_height` setting.
    height: Option<f32>,
    /// Pane whose items are the `AssistantEditor`s created by `add_context`.
    pane: ViewHandle<Pane>,
    /// API key handle shared (same `Rc`) with every editor this panel creates.
    api_key: Rc<RefCell<Option<String>>>,
    /// Present only while the panel is prompting the user for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once activation has tried the credential store, so it is read at
    /// most once.
    has_read_credentials: bool,
    /// Language registry handed to each new `AssistantEditor`.
    languages: Arc<LanguageRegistry>,
    /// Filesystem used to persist dock-position changes to the settings file.
    fs: Arc<dyn Fs>,
    /// Pane, settings, and editor subscriptions held for the panel's lifetime.
    subscriptions: Vec<Subscription>,
}
 105
 106impl AssistantPanel {
 107    pub fn load(
 108        workspace: WeakViewHandle<Workspace>,
 109        cx: AsyncAppContext,
 110    ) -> Task<Result<ViewHandle<Self>>> {
 111        cx.spawn(|mut cx| async move {
 112            // TODO: deserialize state.
 113            workspace.update(&mut cx, |workspace, cx| {
 114                cx.add_view::<Self, _>(|cx| {
 115                    let weak_self = cx.weak_handle();
 116                    let pane = cx.add_view(|cx| {
 117                        let mut pane = Pane::new(
 118                            workspace.weak_handle(),
 119                            workspace.project().clone(),
 120                            workspace.app_state().background_actions,
 121                            Default::default(),
 122                            cx,
 123                        );
 124                        pane.set_can_split(false, cx);
 125                        pane.set_can_navigate(false, cx);
 126                        pane.on_can_drop(move |_, _| false);
 127                        pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
 128                            let weak_self = weak_self.clone();
 129                            Flex::row()
 130                                .with_child(Pane::render_tab_bar_button(
 131                                    0,
 132                                    "icons/plus_12.svg",
 133                                    false,
 134                                    Some(("New Context".into(), Some(Box::new(NewContext)))),
 135                                    cx,
 136                                    move |_, cx| {
 137                                        let weak_self = weak_self.clone();
 138                                        cx.window_context().defer(move |cx| {
 139                                            if let Some(this) = weak_self.upgrade(cx) {
 140                                                this.update(cx, |this, cx| this.add_context(cx));
 141                                            }
 142                                        })
 143                                    },
 144                                    None,
 145                                ))
 146                                .with_child(Pane::render_tab_bar_button(
 147                                    1,
 148                                    if pane.is_zoomed() {
 149                                        "icons/minimize_8.svg"
 150                                    } else {
 151                                        "icons/maximize_8.svg"
 152                                    },
 153                                    pane.is_zoomed(),
 154                                    Some((
 155                                        "Toggle Zoom".into(),
 156                                        Some(Box::new(workspace::ToggleZoom)),
 157                                    )),
 158                                    cx,
 159                                    move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
 160                                    None,
 161                                ))
 162                                .into_any()
 163                        });
 164                        let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
 165                        pane.toolbar()
 166                            .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
 167                        pane
 168                    });
 169
 170                    let mut this = Self {
 171                        pane,
 172                        api_key: Rc::new(RefCell::new(None)),
 173                        api_key_editor: None,
 174                        has_read_credentials: false,
 175                        languages: workspace.app_state().languages.clone(),
 176                        fs: workspace.app_state().fs.clone(),
 177                        width: None,
 178                        height: None,
 179                        subscriptions: Default::default(),
 180                    };
 181
 182                    let mut old_dock_position = this.position(cx);
 183                    this.subscriptions = vec![
 184                        cx.observe(&this.pane, |_, _, cx| cx.notify()),
 185                        cx.subscribe(&this.pane, Self::handle_pane_event),
 186                        cx.observe_global::<SettingsStore, _>(move |this, cx| {
 187                            let new_dock_position = this.position(cx);
 188                            if new_dock_position != old_dock_position {
 189                                old_dock_position = new_dock_position;
 190                                cx.emit(AssistantPanelEvent::DockPositionChanged);
 191                            }
 192                        }),
 193                    ];
 194
 195                    this
 196                })
 197            })
 198        })
 199    }
 200
 201    fn handle_pane_event(
 202        &mut self,
 203        _pane: ViewHandle<Pane>,
 204        event: &pane::Event,
 205        cx: &mut ViewContext<Self>,
 206    ) {
 207        match event {
 208            pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
 209            pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
 210            pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
 211            pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
 212            _ => {}
 213        }
 214    }
 215
 216    fn add_context(&mut self, cx: &mut ViewContext<Self>) {
 217        let focus = self.has_focus(cx);
 218        let editor = cx
 219            .add_view(|cx| AssistantEditor::new(self.api_key.clone(), self.languages.clone(), cx));
 220        self.subscriptions
 221            .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
 222        self.pane.update(cx, |pane, cx| {
 223            pane.add_item(Box::new(editor), true, focus, None, cx)
 224        });
 225    }
 226
 227    fn handle_assistant_editor_event(
 228        &mut self,
 229        _: ViewHandle<AssistantEditor>,
 230        event: &AssistantEditorEvent,
 231        cx: &mut ViewContext<Self>,
 232    ) {
 233        match event {
 234            AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
 235        }
 236    }
 237
 238    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
 239        if let Some(api_key) = self
 240            .api_key_editor
 241            .as_ref()
 242            .map(|editor| editor.read(cx).text(cx))
 243        {
 244            if !api_key.is_empty() {
 245                cx.platform()
 246                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
 247                    .log_err();
 248                *self.api_key.borrow_mut() = Some(api_key);
 249                self.api_key_editor.take();
 250                cx.focus_self();
 251                cx.notify();
 252            }
 253        } else {
 254            cx.propagate_action();
 255        }
 256    }
 257
 258    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
 259        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
 260        self.api_key.take();
 261        self.api_key_editor = Some(build_api_key_editor(cx));
 262        cx.focus_self();
 263        cx.notify();
 264    }
 265}
 266
 267fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
 268    cx.add_view(|cx| {
 269        let mut editor = Editor::single_line(
 270            Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
 271            cx,
 272        );
 273        editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
 274        editor
 275    })
 276}
 277
// The panel's event type; see `AssistantPanelEvent`.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
 281
 282impl View for AssistantPanel {
 283    fn ui_name() -> &'static str {
 284        "AssistantPanel"
 285    }
 286
 287    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
 288        let style = &theme::current(cx).assistant;
 289        if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 290            Flex::column()
 291                .with_child(
 292                    Text::new(
 293                        "Paste your OpenAI API key and press Enter to use the assistant",
 294                        style.api_key_prompt.text.clone(),
 295                    )
 296                    .aligned(),
 297                )
 298                .with_child(
 299                    ChildView::new(api_key_editor, cx)
 300                        .contained()
 301                        .with_style(style.api_key_editor.container)
 302                        .aligned(),
 303                )
 304                .contained()
 305                .with_style(style.api_key_prompt.container)
 306                .aligned()
 307                .into_any()
 308        } else {
 309            ChildView::new(&self.pane, cx).into_any()
 310        }
 311    }
 312
 313    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
 314        if cx.is_self_focused() {
 315            if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 316                cx.focus(api_key_editor);
 317            } else {
 318                cx.focus(&self.pane);
 319            }
 320        }
 321    }
 322}
 323
 324impl Panel for AssistantPanel {
 325    fn position(&self, cx: &WindowContext) -> DockPosition {
 326        match settings::get::<AssistantSettings>(cx).dock {
 327            AssistantDockPosition::Left => DockPosition::Left,
 328            AssistantDockPosition::Bottom => DockPosition::Bottom,
 329            AssistantDockPosition::Right => DockPosition::Right,
 330        }
 331    }
 332
 333    fn position_is_valid(&self, _: DockPosition) -> bool {
 334        true
 335    }
 336
 337    fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
 338        settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
 339            let dock = match position {
 340                DockPosition::Left => AssistantDockPosition::Left,
 341                DockPosition::Bottom => AssistantDockPosition::Bottom,
 342                DockPosition::Right => AssistantDockPosition::Right,
 343            };
 344            settings.dock = Some(dock);
 345        });
 346    }
 347
 348    fn size(&self, cx: &WindowContext) -> f32 {
 349        let settings = settings::get::<AssistantSettings>(cx);
 350        match self.position(cx) {
 351            DockPosition::Left | DockPosition::Right => {
 352                self.width.unwrap_or_else(|| settings.default_width)
 353            }
 354            DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
 355        }
 356    }
 357
 358    fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
 359        match self.position(cx) {
 360            DockPosition::Left | DockPosition::Right => self.width = Some(size),
 361            DockPosition::Bottom => self.height = Some(size),
 362        }
 363        cx.notify();
 364    }
 365
 366    fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
 367        matches!(event, AssistantPanelEvent::ZoomIn)
 368    }
 369
 370    fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
 371        matches!(event, AssistantPanelEvent::ZoomOut)
 372    }
 373
 374    fn is_zoomed(&self, cx: &WindowContext) -> bool {
 375        self.pane.read(cx).is_zoomed()
 376    }
 377
 378    fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
 379        self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
 380    }
 381
 382    fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
 383        if active {
 384            if self.api_key.borrow().is_none() && !self.has_read_credentials {
 385                self.has_read_credentials = true;
 386                let api_key = if let Some((_, api_key)) = cx
 387                    .platform()
 388                    .read_credentials(OPENAI_API_URL)
 389                    .log_err()
 390                    .flatten()
 391                {
 392                    String::from_utf8(api_key).log_err()
 393                } else {
 394                    None
 395                };
 396                if let Some(api_key) = api_key {
 397                    *self.api_key.borrow_mut() = Some(api_key);
 398                } else if self.api_key_editor.is_none() {
 399                    self.api_key_editor = Some(build_api_key_editor(cx));
 400                    cx.notify();
 401                }
 402            }
 403
 404            if self.pane.read(cx).items_len() == 0 {
 405                self.add_context(cx);
 406            }
 407        }
 408    }
 409
 410    fn icon_path(&self) -> &'static str {
 411        "icons/robot_14.svg"
 412    }
 413
 414    fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
 415        ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
 416    }
 417
 418    fn should_change_position_on_event(event: &Self::Event) -> bool {
 419        matches!(event, AssistantPanelEvent::DockPositionChanged)
 420    }
 421
 422    fn should_activate_on_event(_: &Self::Event) -> bool {
 423        false
 424    }
 425
 426    fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
 427        matches!(event, AssistantPanelEvent::Close)
 428    }
 429
 430    fn has_focus(&self, cx: &WindowContext) -> bool {
 431        self.pane.read(cx).has_focus()
 432            || self
 433                .api_key_editor
 434                .as_ref()
 435                .map_or(false, |editor| editor.is_focused(cx))
 436    }
 437
 438    fn is_focus_event(event: &Self::Event) -> bool {
 439        matches!(event, AssistantPanelEvent::Focus)
 440    }
 441}
 442
/// Events emitted by an [`Assistant`] model.
enum AssistantEvent {
    /// Messages changed. This covers buffer edits, inserted or split messages,
    /// and role changes.
    MessagesEdited,
    /// NOTE(review): presumably emitted when `summary` is updated; the emitter
    /// is not visible in this section.
    SummaryChanged,
    /// A chunk of streamed completion text was inserted into the buffer.
    StreamedCompletion,
}
 448
/// One assistant conversation: a buffer holding all message text, the anchors
/// that split it into messages, per-message metadata, and in-flight requests.
struct Assistant {
    /// Text of the whole conversation (its language is set to Markdown once loaded).
    buffer: ModelHandle<Buffer>,
    /// Start anchor of each message, in buffer order.
    message_anchors: Vec<MessageAnchor>,
    /// Role, send time, and last error for each message, keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Next id handed out (via `post_inc`) to a new message.
    next_message_id: MessageId,
    /// NOTE(review): presumably a generated conversation title; written by
    /// `summarize`, which is not visible here.
    summary: Option<String>,
    pending_summary: Task<Option<()>>,
    /// Counter used (via `post_inc`) to assign ids to `pending_completions`.
    completion_count: usize,
    /// Completion tasks started by `assist`, most recent last.
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name sent with each request, e.g. "gpt-3.5-turbo".
    model: String,
    /// Token count of the current request messages; `None` until first counted.
    token_count: Option<usize>,
    /// Context-window size of `model`, per `tiktoken_rs`.
    max_token_count: usize,
    /// Latest scheduled token recount; replaced on every recount.
    pending_token_count: Task<Option<()>>,
    /// Shared API key; `assist` does nothing while it is `None`.
    api_key: Rc<RefCell<Option<String>>>,
    _subscriptions: Vec<Subscription>,
}
 465
// The assistant model's event type; see `AssistantEvent`.
impl Entity for Assistant {
    type Event = AssistantEvent;
}
 469
 470impl Assistant {
 471    fn new(
 472        api_key: Rc<RefCell<Option<String>>>,
 473        language_registry: Arc<LanguageRegistry>,
 474        cx: &mut ModelContext<Self>,
 475    ) -> Self {
 476        let model = "gpt-3.5-turbo";
 477        let markdown = language_registry.language_for_name("Markdown");
 478        let buffer = cx.add_model(|cx| {
 479            let mut buffer = Buffer::new(0, "", cx);
 480            buffer.set_language_registry(language_registry);
 481            cx.spawn_weak(|buffer, mut cx| async move {
 482                let markdown = markdown.await?;
 483                let buffer = buffer
 484                    .upgrade(&cx)
 485                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
 486                buffer.update(&mut cx, |buffer, cx| {
 487                    buffer.set_language(Some(markdown), cx)
 488                });
 489                anyhow::Ok(())
 490            })
 491            .detach_and_log_err(cx);
 492            buffer
 493        });
 494
 495        let mut this = Self {
 496            message_anchors: Default::default(),
 497            messages_metadata: Default::default(),
 498            next_message_id: Default::default(),
 499            summary: None,
 500            pending_summary: Task::ready(None),
 501            completion_count: Default::default(),
 502            pending_completions: Default::default(),
 503            token_count: None,
 504            max_token_count: tiktoken_rs::model::get_context_size(model),
 505            pending_token_count: Task::ready(None),
 506            model: model.into(),
 507            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
 508            api_key,
 509            buffer,
 510        };
 511        let message = MessageAnchor {
 512            id: MessageId(post_inc(&mut this.next_message_id.0)),
 513            start: language::Anchor::MIN,
 514        };
 515        this.message_anchors.push(message.clone());
 516        this.messages_metadata.insert(
 517            message.id,
 518            MessageMetadata {
 519                role: Role::User,
 520                sent_at: Local::now(),
 521                error: None,
 522            },
 523        );
 524
 525        this.count_remaining_tokens(cx);
 526        this
 527    }
 528
 529    fn handle_buffer_event(
 530        &mut self,
 531        _: ModelHandle<Buffer>,
 532        event: &language::Event,
 533        cx: &mut ModelContext<Self>,
 534    ) {
 535        match event {
 536            language::Event::Edited => {
 537                self.count_remaining_tokens(cx);
 538                cx.emit(AssistantEvent::MessagesEdited);
 539            }
 540            _ => {}
 541        }
 542    }
 543
 544    fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
 545        let messages = self
 546            .open_ai_request_messages(cx)
 547            .into_iter()
 548            .filter_map(|message| {
 549                Some(tiktoken_rs::ChatCompletionRequestMessage {
 550                    role: match message.role {
 551                        Role::User => "user".into(),
 552                        Role::Assistant => "assistant".into(),
 553                        Role::System => "system".into(),
 554                    },
 555                    content: message.content,
 556                    name: None,
 557                })
 558            })
 559            .collect::<Vec<_>>();
 560        let model = self.model.clone();
 561        self.pending_token_count = cx.spawn_weak(|this, mut cx| {
 562            async move {
 563                cx.background().timer(Duration::from_millis(200)).await;
 564                let token_count = cx
 565                    .background()
 566                    .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
 567                    .await?;
 568
 569                this.upgrade(&cx)
 570                    .ok_or_else(|| anyhow!("assistant was dropped"))?
 571                    .update(&mut cx, |this, cx| {
 572                        this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
 573                        this.token_count = Some(token_count);
 574                        cx.notify()
 575                    });
 576                anyhow::Ok(())
 577            }
 578            .log_err()
 579        });
 580    }
 581
 582    fn remaining_tokens(&self) -> Option<isize> {
 583        Some(self.max_token_count as isize - self.token_count? as isize)
 584    }
 585
    /// Switches the OpenAI model used for future requests and schedules a token
    /// recount for it; `max_token_count` is refreshed when that recount completes.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        // The recount reads `self.model`, so it must run after the assignment.
        self.count_remaining_tokens(cx);
        cx.notify();
    }
 591
 592    fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
 593        let request = OpenAIRequest {
 594            model: self.model.clone(),
 595            messages: self.open_ai_request_messages(cx),
 596            stream: true,
 597        };
 598
 599        let api_key = self.api_key.borrow().clone()?;
 600        let stream = stream_completion(api_key, cx.background().clone(), request);
 601        let assistant_message =
 602            self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
 603        let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
 604        let task = cx.spawn_weak({
 605            |this, mut cx| async move {
 606                let assistant_message_id = assistant_message.id;
 607                let stream_completion = async {
 608                    let mut messages = stream.await?;
 609
 610                    while let Some(message) = messages.next().await {
 611                        let mut message = message?;
 612                        if let Some(choice) = message.choices.pop() {
 613                            this.upgrade(&cx)
 614                                .ok_or_else(|| anyhow!("assistant was dropped"))?
 615                                .update(&mut cx, |this, cx| {
 616                                    let text: Arc<str> = choice.delta.content?.into();
 617                                    let message_ix = this
 618                                        .message_anchors
 619                                        .iter()
 620                                        .position(|message| message.id == assistant_message_id)?;
 621                                    this.buffer.update(cx, |buffer, cx| {
 622                                        let offset = if message_ix + 1 == this.message_anchors.len()
 623                                        {
 624                                            buffer.len()
 625                                        } else {
 626                                            this.message_anchors[message_ix + 1]
 627                                                .start
 628                                                .to_offset(buffer)
 629                                                .saturating_sub(1)
 630                                        };
 631                                        buffer.edit([(offset..offset, text)], None, cx);
 632                                    });
 633                                    cx.emit(AssistantEvent::StreamedCompletion);
 634
 635                                    Some(())
 636                                });
 637                        }
 638                    }
 639
 640                    this.upgrade(&cx)
 641                        .ok_or_else(|| anyhow!("assistant was dropped"))?
 642                        .update(&mut cx, |this, cx| {
 643                            this.pending_completions
 644                                .retain(|completion| completion.id != this.completion_count);
 645                            this.summarize(cx);
 646                        });
 647
 648                    anyhow::Ok(())
 649                };
 650
 651                let result = stream_completion.await;
 652                if let Some(this) = this.upgrade(&cx) {
 653                    this.update(&mut cx, |this, cx| {
 654                        if let Err(error) = result {
 655                            if let Some(metadata) =
 656                                this.messages_metadata.get_mut(&assistant_message.id)
 657                            {
 658                                metadata.error = Some(error.to_string().trim().into());
 659                                cx.notify();
 660                            }
 661                        }
 662                    });
 663                }
 664            }
 665        });
 666
 667        self.pending_completions.push(PendingCompletion {
 668            id: post_inc(&mut self.completion_count),
 669            _task: task,
 670        });
 671        Some((assistant_message, user_message))
 672    }
 673
 674    fn cancel_last_assist(&mut self) -> bool {
 675        self.pending_completions.pop().is_some()
 676    }
 677
 678    fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
 679        if let Some(metadata) = self.messages_metadata.get_mut(&id) {
 680            metadata.role.cycle();
 681            cx.emit(AssistantEvent::MessagesEdited);
 682            cx.notify();
 683        }
 684    }
 685
 686    fn insert_message_after(
 687        &mut self,
 688        message_id: MessageId,
 689        role: Role,
 690        cx: &mut ModelContext<Self>,
 691    ) -> Option<MessageAnchor> {
 692        if let Some(prev_message_ix) = self
 693            .message_anchors
 694            .iter()
 695            .position(|message| message.id == message_id)
 696        {
 697            let start = self.buffer.update(cx, |buffer, cx| {
 698                let offset = self.message_anchors[prev_message_ix + 1..]
 699                    .iter()
 700                    .find(|message| message.start.is_valid(buffer))
 701                    .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
 702                buffer.edit([(offset..offset, "\n")], None, cx);
 703                buffer.anchor_before(offset + 1)
 704            });
 705            let message = MessageAnchor {
 706                id: MessageId(post_inc(&mut self.next_message_id.0)),
 707                start,
 708            };
 709            self.message_anchors
 710                .insert(prev_message_ix + 1, message.clone());
 711            self.messages_metadata.insert(
 712                message.id,
 713                MessageMetadata {
 714                    role,
 715                    sent_at: Local::now(),
 716                    error: None,
 717                },
 718            );
 719            cx.emit(AssistantEvent::MessagesEdited);
 720            Some(message)
 721        } else {
 722            None
 723        }
 724    }
 725
 726    fn split_message(
 727        &mut self,
 728        range: Range<usize>,
 729        cx: &mut ModelContext<Self>,
 730    ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
 731        let start_message = self.message_for_offset(range.start, cx);
 732        let end_message = self.message_for_offset(range.end, cx);
 733        if let Some((start_message, end_message)) = start_message.zip(end_message) {
 734            // Prevent splitting when range spans multiple messages.
 735            if start_message.index != end_message.index {
 736                return (None, None);
 737            }
 738
 739            let message = start_message;
 740            let role = message.role;
 741            let mut edited_buffer = false;
 742
 743            let mut suffix_start = None;
 744            if range.start > message.range.start && range.end < message.range.end - 1 {
 745                if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
 746                    suffix_start = Some(range.end + 1);
 747                } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
 748                    suffix_start = Some(range.end);
 749                }
 750            }
 751
 752            let suffix = if let Some(suffix_start) = suffix_start {
 753                MessageAnchor {
 754                    id: MessageId(post_inc(&mut self.next_message_id.0)),
 755                    start: self.buffer.read(cx).anchor_before(suffix_start),
 756                }
 757            } else {
 758                self.buffer.update(cx, |buffer, cx| {
 759                    buffer.edit([(range.end..range.end, "\n")], None, cx);
 760                });
 761                edited_buffer = true;
 762                MessageAnchor {
 763                    id: MessageId(post_inc(&mut self.next_message_id.0)),
 764                    start: self.buffer.read(cx).anchor_before(range.end + 1),
 765                }
 766            };
 767
 768            self.message_anchors
 769                .insert(message.index + 1, suffix.clone());
 770            self.messages_metadata.insert(
 771                suffix.id,
 772                MessageMetadata {
 773                    role,
 774                    sent_at: Local::now(),
 775                    error: None,
 776                },
 777            );
 778
 779            let new_messages = if range.start == range.end || range.start == message.range.start {
 780                (None, Some(suffix))
 781            } else {
 782                let mut prefix_end = None;
 783                if range.start > message.range.start && range.end < message.range.end - 1 {
 784                    if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
 785                        prefix_end = Some(range.start + 1);
 786                    } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
 787                        == Some('\n')
 788                    {
 789                        prefix_end = Some(range.start);
 790                    }
 791                }
 792
 793                let selection = if let Some(prefix_end) = prefix_end {
 794                    cx.emit(AssistantEvent::MessagesEdited);
 795                    MessageAnchor {
 796                        id: MessageId(post_inc(&mut self.next_message_id.0)),
 797                        start: self.buffer.read(cx).anchor_before(prefix_end),
 798                    }
 799                } else {
 800                    self.buffer.update(cx, |buffer, cx| {
 801                        buffer.edit([(range.start..range.start, "\n")], None, cx)
 802                    });
 803                    edited_buffer = true;
 804                    MessageAnchor {
 805                        id: MessageId(post_inc(&mut self.next_message_id.0)),
 806                        start: self.buffer.read(cx).anchor_before(range.end + 1),
 807                    }
 808                };
 809
 810                self.message_anchors
 811                    .insert(message.index + 1, selection.clone());
 812                self.messages_metadata.insert(
 813                    selection.id,
 814                    MessageMetadata {
 815                        role,
 816                        sent_at: Local::now(),
 817                        error: None,
 818                    },
 819                );
 820                (Some(selection), Some(suffix))
 821            };
 822
 823            if !edited_buffer {
 824                cx.emit(AssistantEvent::MessagesEdited);
 825            }
 826            new_messages
 827        } else {
 828            (None, None)
 829        }
 830    }
 831
 832    fn summarize(&mut self, cx: &mut ModelContext<Self>) {
 833        if self.message_anchors.len() >= 2 && self.summary.is_none() {
 834            let api_key = self.api_key.borrow().clone();
 835            if let Some(api_key) = api_key {
 836                let mut messages = self.open_ai_request_messages(cx);
 837                messages.truncate(2);
 838                messages.push(RequestMessage {
 839                    role: Role::User,
 840                    content: "Summarize the conversation into a short title without punctuation"
 841                        .into(),
 842                });
 843                let request = OpenAIRequest {
 844                    model: self.model.clone(),
 845                    messages,
 846                    stream: true,
 847                };
 848
 849                let stream = stream_completion(api_key, cx.background().clone(), request);
 850                self.pending_summary = cx.spawn(|this, mut cx| {
 851                    async move {
 852                        let mut messages = stream.await?;
 853
 854                        while let Some(message) = messages.next().await {
 855                            let mut message = message?;
 856                            if let Some(choice) = message.choices.pop() {
 857                                let text = choice.delta.content.unwrap_or_default();
 858                                this.update(&mut cx, |this, cx| {
 859                                    this.summary.get_or_insert(String::new()).push_str(&text);
 860                                    cx.emit(AssistantEvent::SummaryChanged);
 861                                });
 862                            }
 863                        }
 864
 865                        anyhow::Ok(())
 866                    }
 867                    .log_err()
 868                });
 869            }
 870        }
 871    }
 872
 873    fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
 874        let buffer = self.buffer.read(cx);
 875        self.messages(cx)
 876            .map(|message| RequestMessage {
 877                role: message.role,
 878                content: buffer.text_for_range(message.range).collect(),
 879            })
 880            .collect()
 881    }
 882
 883    fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
 884        let mut messages = self.messages(cx).peekable();
 885        while let Some(message) = messages.next() {
 886            if message.range.contains(&offset) || messages.peek().is_none() {
 887                return Some(message);
 888            }
 889        }
 890        None
 891    }
 892
 893    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
 894        let buffer = self.buffer.read(cx);
 895        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
 896        iter::from_fn(move || {
 897            while let Some((ix, message_anchor)) = message_anchors.next() {
 898                let metadata = self.messages_metadata.get(&message_anchor.id)?;
 899                let message_start = message_anchor.start.to_offset(buffer);
 900                let mut message_end = None;
 901                while let Some((_, next_message)) = message_anchors.peek() {
 902                    if next_message.start.is_valid(buffer) {
 903                        message_end = Some(next_message.start);
 904                        break;
 905                    } else {
 906                        message_anchors.next();
 907                    }
 908                }
 909                let message_end = message_end
 910                    .unwrap_or(language::Anchor::MAX)
 911                    .to_offset(buffer);
 912                return Some(Message {
 913                    index: ix,
 914                    range: message_start..message_end,
 915                    id: message_anchor.id,
 916                    anchor: message_anchor.start,
 917                    role: metadata.role,
 918                    sent_at: metadata.sent_at,
 919                    error: metadata.error.clone(),
 920                });
 921            }
 922            None
 923        })
 924    }
 925}
 926
/// Bookkeeping for a completion request. NOTE(review): presumably one in flight — confirm at
/// the construction site.
struct PendingCompletion {
    /// Identifier for this completion. NOTE(review): presumably used to tell which request to
    /// cancel — confirm at the use site.
    id: usize,
    /// Task driving the completion. It is held only to keep the task alive; presumably dropping it
    /// cancels the work, as with other gpui tasks — confirm.
    _task: Task<()>,
}
 931
/// Events emitted by [`AssistantEditor`].
enum AssistantEditorEvent {
    /// The tab title may have changed. Emitted when the conversation summary changes.
    TabContentChanged,
}
 935
/// Where the cursor sits within the editor viewport.
///
/// This is recorded so the viewport can keep the cursor on the same row while a completion
/// streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x` is the horizontal scroll position; `y` is the number of display rows between the top
    /// of the viewport and the cursor.
    offset_before_cursor: Vector2F,
    /// Anchor of the newest selection's head.
    cursor: Anchor,
}
 941
/// A view for editing one [`Assistant`] conversation through an embedded [`Editor`].
struct AssistantEditor {
    /// The conversation model whose buffer the editor displays.
    assistant: ModelHandle<Assistant>,
    /// Editor over the conversation buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// The cursor's viewport position to preserve while completions stream in, if any.
    scroll_position: Option<ScrollPosition>,
    /// Keeps the model and editor subscriptions alive for as long as the view exists.
    _subscriptions: Vec<Subscription>,
}
 949
 950impl AssistantEditor {
 951    fn new(
 952        api_key: Rc<RefCell<Option<String>>>,
 953        language_registry: Arc<LanguageRegistry>,
 954        cx: &mut ViewContext<Self>,
 955    ) -> Self {
 956        let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
 957        let editor = cx.add_view(|cx| {
 958            let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
 959            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
 960            editor.set_show_gutter(false, cx);
 961            editor
 962        });
 963
 964        let _subscriptions = vec![
 965            cx.observe(&assistant, |_, _, cx| cx.notify()),
 966            cx.subscribe(&assistant, Self::handle_assistant_event),
 967            cx.subscribe(&editor, Self::handle_editor_event),
 968        ];
 969
 970        let mut this = Self {
 971            assistant,
 972            editor,
 973            blocks: Default::default(),
 974            scroll_position: None,
 975            _subscriptions,
 976        };
 977        this.update_message_headers(cx);
 978        this
 979    }
 980
 981    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
 982        let user_message = self.assistant.update(cx, |assistant, cx| {
 983            let (_, user_message) = assistant.assist(cx)?;
 984            Some(user_message)
 985        });
 986
 987        if let Some(user_message) = user_message {
 988            let cursor = user_message
 989                .start
 990                .to_offset(&self.assistant.read(cx).buffer.read(cx));
 991            self.editor.update(cx, |editor, cx| {
 992                editor.change_selections(
 993                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
 994                    cx,
 995                    |selections| selections.select_ranges([cursor..cursor]),
 996                );
 997            });
 998        }
 999    }
1000
1001    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1002        if !self
1003            .assistant
1004            .update(cx, |assistant, _| assistant.cancel_last_assist())
1005        {
1006            cx.propagate_action();
1007        }
1008    }
1009
1010    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1011        let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
1012        self.assistant.update(cx, |assistant, cx| {
1013            if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
1014                assistant.cycle_message_role(message.id, cx);
1015            }
1016        });
1017    }
1018
1019    fn handle_assistant_event(
1020        &mut self,
1021        _: ModelHandle<Assistant>,
1022        event: &AssistantEvent,
1023        cx: &mut ViewContext<Self>,
1024    ) {
1025        match event {
1026            AssistantEvent::MessagesEdited => self.update_message_headers(cx),
1027            AssistantEvent::SummaryChanged => {
1028                cx.emit(AssistantEditorEvent::TabContentChanged);
1029            }
1030            AssistantEvent::StreamedCompletion => {
1031                self.editor.update(cx, |editor, cx| {
1032                    if let Some(scroll_position) = self.scroll_position {
1033                        let snapshot = editor.snapshot(cx);
1034                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1035                        let scroll_top =
1036                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1037                        editor.set_scroll_position(
1038                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1039                            cx,
1040                        );
1041                    }
1042                });
1043            }
1044        }
1045    }
1046
1047    fn handle_editor_event(
1048        &mut self,
1049        _: ViewHandle<Editor>,
1050        event: &editor::Event,
1051        cx: &mut ViewContext<Self>,
1052    ) {
1053        match event {
1054            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1055                let cursor_scroll_position = self.cursor_scroll_position(cx);
1056                if *autoscroll {
1057                    self.scroll_position = cursor_scroll_position;
1058                } else if self.scroll_position != cursor_scroll_position {
1059                    self.scroll_position = None;
1060                }
1061            }
1062            editor::Event::SelectionsChanged { .. } => {
1063                self.scroll_position = self.cursor_scroll_position(cx);
1064            }
1065            _ => {}
1066        }
1067    }
1068
1069    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1070        self.editor.update(cx, |editor, cx| {
1071            let snapshot = editor.snapshot(cx);
1072            let cursor = editor.selections.newest_anchor().head();
1073            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1074            let scroll_position = editor
1075                .scroll_manager
1076                .anchor()
1077                .scroll_position(&snapshot.display_snapshot);
1078
1079            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1080            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1081                Some(ScrollPosition {
1082                    cursor,
1083                    offset_before_cursor: vec2f(
1084                        scroll_position.x(),
1085                        cursor_row - scroll_position.y(),
1086                    ),
1087                })
1088            } else {
1089                None
1090            }
1091        })
1092    }
1093
1094    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
1095        self.editor.update(cx, |editor, cx| {
1096            let buffer = editor.buffer().read(cx).snapshot(cx);
1097            let excerpt_id = *buffer.as_singleton().unwrap().0;
1098            let old_blocks = std::mem::take(&mut self.blocks);
1099            let new_blocks = self
1100                .assistant
1101                .read(cx)
1102                .messages(cx)
1103                .map(|message| BlockProperties {
1104                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
1105                    height: 2,
1106                    style: BlockStyle::Sticky,
1107                    render: Arc::new({
1108                        let assistant = self.assistant.clone();
1109                        // let metadata = message.metadata.clone();
1110                        // let message = message.clone();
1111                        move |cx| {
1112                            enum Sender {}
1113                            enum ErrorTooltip {}
1114
1115                            let theme = theme::current(cx);
1116                            let style = &theme.assistant;
1117                            let message_id = message.id;
1118                            let sender = MouseEventHandler::<Sender, _>::new(
1119                                message_id.0,
1120                                cx,
1121                                |state, _| match message.role {
1122                                    Role::User => {
1123                                        let style = style.user_sender.style_for(state, false);
1124                                        Label::new("You", style.text.clone())
1125                                            .contained()
1126                                            .with_style(style.container)
1127                                    }
1128                                    Role::Assistant => {
1129                                        let style = style.assistant_sender.style_for(state, false);
1130                                        Label::new("Assistant", style.text.clone())
1131                                            .contained()
1132                                            .with_style(style.container)
1133                                    }
1134                                    Role::System => {
1135                                        let style = style.system_sender.style_for(state, false);
1136                                        Label::new("System", style.text.clone())
1137                                            .contained()
1138                                            .with_style(style.container)
1139                                    }
1140                                },
1141                            )
1142                            .with_cursor_style(CursorStyle::PointingHand)
1143                            .on_down(MouseButton::Left, {
1144                                let assistant = assistant.clone();
1145                                move |_, _, cx| {
1146                                    assistant.update(cx, |assistant, cx| {
1147                                        assistant.cycle_message_role(message_id, cx)
1148                                    })
1149                                }
1150                            });
1151
1152                            Flex::row()
1153                                .with_child(sender.aligned())
1154                                .with_child(
1155                                    Label::new(
1156                                        message.sent_at.format("%I:%M%P").to_string(),
1157                                        style.sent_at.text.clone(),
1158                                    )
1159                                    .contained()
1160                                    .with_style(style.sent_at.container)
1161                                    .aligned(),
1162                                )
1163                                .with_children(message.error.as_ref().map(|error| {
1164                                    Svg::new("icons/circle_x_mark_12.svg")
1165                                        .with_color(style.error_icon.color)
1166                                        .constrained()
1167                                        .with_width(style.error_icon.width)
1168                                        .contained()
1169                                        .with_style(style.error_icon.container)
1170                                        .with_tooltip::<ErrorTooltip>(
1171                                            message_id.0,
1172                                            error.to_string(),
1173                                            None,
1174                                            theme.tooltip.clone(),
1175                                            cx,
1176                                        )
1177                                        .aligned()
1178                                }))
1179                                .aligned()
1180                                .left()
1181                                .contained()
1182                                .with_style(style.header)
1183                                .into_any()
1184                        }
1185                    }),
1186                    disposition: BlockDisposition::Above,
1187                })
1188                .collect::<Vec<_>>();
1189
1190            editor.remove_blocks(old_blocks, None, cx);
1191            let ids = editor.insert_blocks(new_blocks, None, cx);
1192            self.blocks = HashSet::from_iter(ids);
1193        });
1194    }
1195
1196    fn quote_selection(
1197        workspace: &mut Workspace,
1198        _: &QuoteSelection,
1199        cx: &mut ViewContext<Workspace>,
1200    ) {
1201        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1202            return;
1203        };
1204        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1205            return;
1206        };
1207
1208        let text = editor.read_with(cx, |editor, cx| {
1209            let range = editor.selections.newest::<usize>(cx).range();
1210            let buffer = editor.buffer().read(cx).snapshot(cx);
1211            let start_language = buffer.language_at(range.start);
1212            let end_language = buffer.language_at(range.end);
1213            let language_name = if start_language == end_language {
1214                start_language.map(|language| language.name())
1215            } else {
1216                None
1217            };
1218            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1219
1220            let selected_text = buffer.text_for_range(range).collect::<String>();
1221            if selected_text.is_empty() {
1222                None
1223            } else {
1224                Some(if language_name == "markdown" {
1225                    selected_text
1226                        .lines()
1227                        .map(|line| format!("> {}", line))
1228                        .collect::<Vec<_>>()
1229                        .join("\n")
1230                } else {
1231                    format!("```{language_name}\n{selected_text}\n```")
1232                })
1233            }
1234        });
1235
1236        // Activate the panel
1237        if !panel.read(cx).has_focus(cx) {
1238            workspace.toggle_panel_focus::<AssistantPanel>(cx);
1239        }
1240
1241        if let Some(text) = text {
1242            panel.update(cx, |panel, cx| {
1243                if let Some(assistant) = panel
1244                    .pane
1245                    .read(cx)
1246                    .active_item()
1247                    .and_then(|item| item.downcast::<AssistantEditor>())
1248                    .ok_or_else(|| anyhow!("no active context"))
1249                    .log_err()
1250                {
1251                    assistant.update(cx, |assistant, cx| {
1252                        assistant
1253                            .editor
1254                            .update(cx, |editor, cx| editor.insert(&text, cx))
1255                    });
1256                }
1257            });
1258        }
1259    }
1260
1261    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1262        let editor = self.editor.read(cx);
1263        let assistant = self.assistant.read(cx);
1264        if editor.selections.count() == 1 {
1265            let selection = editor.selections.newest::<usize>(cx);
1266            let mut copied_text = String::new();
1267            let mut spanned_messages = 0;
1268            for message in assistant.messages(cx) {
1269                if message.range.start >= selection.range().end {
1270                    break;
1271                } else if message.range.end >= selection.range().start {
1272                    let range = cmp::max(message.range.start, selection.range().start)
1273                        ..cmp::min(message.range.end, selection.range().end);
1274                    if !range.is_empty() {
1275                        spanned_messages += 1;
1276                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1277                        for chunk in assistant.buffer.read(cx).text_for_range(range) {
1278                            copied_text.push_str(&chunk);
1279                        }
1280                        copied_text.push('\n');
1281                    }
1282                }
1283            }
1284
1285            if spanned_messages > 1 {
1286                cx.platform()
1287                    .write_to_clipboard(ClipboardItem::new(copied_text));
1288                return;
1289            }
1290        }
1291
1292        cx.propagate_action();
1293    }
1294
1295    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1296        self.assistant.update(cx, |assistant, cx| {
1297            let range = self.editor.read(cx).selections.newest::<usize>(cx).range();
1298            assistant.split_message(range, cx);
1299        });
1300    }
1301
1302    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1303        self.assistant.update(cx, |assistant, cx| {
1304            let new_model = match assistant.model.as_str() {
1305                "gpt-4" => "gpt-3.5-turbo",
1306                _ => "gpt-4",
1307            };
1308            assistant.set_model(new_model.into(), cx);
1309        });
1310    }
1311
1312    fn title(&self, cx: &AppContext) -> String {
1313        self.assistant
1314            .read(cx)
1315            .summary
1316            .clone()
1317            .unwrap_or_else(|| "New Context".into())
1318    }
1319}
1320
/// [`AssistantEditor`] emits [`AssistantEditorEvent`]s to its subscribers.
impl Entity for AssistantEditor {
    type Event = AssistantEditorEvent;
}
1324
1325impl View for AssistantEditor {
1326    fn ui_name() -> &'static str {
1327        "AssistantEditor"
1328    }
1329
1330    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
1331        enum Model {}
1332        let theme = &theme::current(cx).assistant;
1333        let assistant = &self.assistant.read(cx);
1334        let model = assistant.model.clone();
1335        let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
1336            let remaining_tokens_style = if remaining_tokens <= 0 {
1337                &theme.no_remaining_tokens
1338            } else {
1339                &theme.remaining_tokens
1340            };
1341            Label::new(
1342                remaining_tokens.to_string(),
1343                remaining_tokens_style.text.clone(),
1344            )
1345            .contained()
1346            .with_style(remaining_tokens_style.container)
1347        });
1348
1349        Stack::new()
1350            .with_child(
1351                ChildView::new(&self.editor, cx)
1352                    .contained()
1353                    .with_style(theme.container),
1354            )
1355            .with_child(
1356                Flex::row()
1357                    .with_child(
1358                        MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
1359                            let style = theme.model.style_for(state, false);
1360                            Label::new(model, style.text.clone())
1361                                .contained()
1362                                .with_style(style.container)
1363                        })
1364                        .with_cursor_style(CursorStyle::PointingHand)
1365                        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
1366                    )
1367                    .with_children(remaining_tokens)
1368                    .contained()
1369                    .with_style(theme.model_info_container)
1370                    .aligned()
1371                    .top()
1372                    .right(),
1373            )
1374            .into_any()
1375    }
1376
1377    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
1378        if cx.is_self_focused() {
1379            cx.focus(&self.editor);
1380        }
1381    }
1382}
1383
1384impl Item for AssistantEditor {
1385    fn tab_content<V: View>(
1386        &self,
1387        _: Option<usize>,
1388        style: &theme::Tab,
1389        cx: &gpui::AppContext,
1390    ) -> AnyElement<V> {
1391        let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1392        Label::new(title, style.label.clone()).into_any()
1393    }
1394
1395    fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1396        Some(self.title(cx).into())
1397    }
1398
1399    fn as_searchable(
1400        &self,
1401        _: &ViewHandle<Self>,
1402    ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1403        Some(Box::new(self.editor.clone()))
1404    }
1405}
1406
/// Opaque identifier of a message within a conversation.
///
/// New ids are handed out by post-incrementing the conversation's `next_message_id` counter.
/// The id keys the message's entry in `messages_metadata`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1409
/// Marks where a message begins in the conversation buffer.
///
/// A message extends to the start of the next still-valid `MessageAnchor`, or to the end of the
/// buffer.
#[derive(Clone, Debug)]
struct MessageAnchor {
    /// Identifier of the message this anchor starts.
    id: MessageId,
    /// Buffer position of the message's first character.
    start: language::Anchor,
}
1415
/// Per-message data that is not stored in the buffer, looked up by [`MessageId`].
#[derive(Clone, Debug)]
struct MessageMetadata {
    /// Who authored the message (user, assistant, or system).
    role: Role,
    /// When the message was created. It is shown in the message header.
    sent_at: DateTime<Local>,
    /// Error text associated with the message, if any. It is shown as the tooltip of an error icon
    /// in the message header.
    error: Option<Arc<str>>,
}
1422
/// A message resolved against the current buffer contents.
///
/// This combines a message's anchor with its metadata, as yielded by the conversation's
/// `messages` iterator.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets covered by the message, up to the start of the next message.
    range: Range<usize>,
    /// Position of the message's anchor in `message_anchors`.
    index: usize,
    /// Identifier of the message.
    id: MessageId,
    /// Buffer anchor at the start of the message.
    anchor: language::Anchor,
    /// Who authored the message.
    role: Role,
    /// When the message was created.
    sent_at: DateTime<Local>,
    /// Error text associated with the message, if any.
    error: Option<Arc<str>>,
}
1433
1434async fn stream_completion(
1435    api_key: String,
1436    executor: Arc<Background>,
1437    mut request: OpenAIRequest,
1438) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1439    request.stream = true;
1440
1441    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1442
1443    let json_data = serde_json::to_string(&request)?;
1444    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1445        .header("Content-Type", "application/json")
1446        .header("Authorization", format!("Bearer {}", api_key))
1447        .body(json_data)?
1448        .send_async()
1449        .await?;
1450
1451    let status = response.status();
1452    if status == StatusCode::OK {
1453        executor
1454            .spawn(async move {
1455                let mut lines = BufReader::new(response.body_mut()).lines();
1456
1457                fn parse_line(
1458                    line: Result<String, io::Error>,
1459                ) -> Result<Option<OpenAIResponseStreamEvent>> {
1460                    if let Some(data) = line?.strip_prefix("data: ") {
1461                        let event = serde_json::from_str(&data)?;
1462                        Ok(Some(event))
1463                    } else {
1464                        Ok(None)
1465                    }
1466                }
1467
1468                while let Some(line) = lines.next().await {
1469                    if let Some(event) = parse_line(line).transpose() {
1470                        let done = event.as_ref().map_or(false, |event| {
1471                            event
1472                                .choices
1473                                .last()
1474                                .map_or(false, |choice| choice.finish_reason.is_some())
1475                        });
1476                        if tx.unbounded_send(event).is_err() {
1477                            break;
1478                        }
1479
1480                        if done {
1481                            break;
1482                        }
1483                    }
1484                }
1485
1486                anyhow::Ok(())
1487            })
1488            .detach();
1489
1490        Ok(rx)
1491    } else {
1492        let mut body = String::new();
1493        response.body_mut().read_to_string(&mut body).await?;
1494
1495        #[derive(Deserialize)]
1496        struct OpenAIResponse {
1497            error: OpenAIError,
1498        }
1499
1500        #[derive(Deserialize)]
1501        struct OpenAIError {
1502            message: String,
1503        }
1504
1505        match serde_json::from_str::<OpenAIResponse>(&body) {
1506            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1507                "Failed to connect to OpenAI API: {}",
1508                response.error.message,
1509            )),
1510
1511            _ => Err(anyhow!(
1512                "Failed to connect to OpenAI API: {} {}",
1513                response.status(),
1514                body,
1515            )),
1516        }
1517    }
1518}
1519
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::AppContext;

    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let languages = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), languages, cx));
        let buffer = assistant.read(cx).buffer.clone();

        // A freshly created assistant holds a single, empty user message.
        let first = assistant.read(cx).message_anchors[0].id;
        assert_eq!(messages(&assistant, cx), vec![(first, Role::User, 0..0)]);

        // Inserting after the first message adds an empty message right after it,
        // and the first message's range grows to 0..1.
        let second = insert_after(&assistant, first, Role::Assistant, cx);
        assert_eq!(
            messages(&assistant, cx),
            vec![(first, Role::User, 0..1), (second, Role::Assistant, 1..1)]
        );

        // Typing inside each message extends that message's range.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![(first, Role::User, 0..2), (second, Role::Assistant, 2..3)]
        );

        let third = insert_after(&assistant, second, Role::User, cx);
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (first, Role::User, 0..2),
                (second, Role::Assistant, 2..4),
                (third, Role::User, 4..4),
            ]
        );

        // Inserting after `second` again places the new message ahead of `third`.
        let fourth = insert_after(&assistant, second, Role::User, cx);
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (first, Role::User, 0..2),
                (second, Role::Assistant, 2..4),
                (fourth, Role::User, 4..5),
                (third, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        // The layout before the merge is checked twice: now, and again after undo.
        let before_merge = vec![
            (first, Role::User, 0..2),
            (second, Role::Assistant, 2..4),
            (fourth, Role::User, 4..6),
            (third, Role::User, 6..7),
        ];
        assert_eq!(messages(&assistant, cx), before_merge);

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        let after_merge = vec![(first, Role::User, 0..3), (third, Role::User, 3..4)];
        assert_eq!(messages(&assistant, cx), after_merge);

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(messages(&assistant, cx), before_merge);

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(messages(&assistant, cx), after_merge);

        // Ensure we can still insert after a merged message.
        let fifth = insert_after(&assistant, first, Role::System, cx);
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (first, Role::User, 0..3),
                (fifth, Role::System, 3..4),
                (third, Role::User, 4..5),
            ]
        );
    }

    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let languages = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), languages, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let first = assistant.read(cx).message_anchors[0].id;
        assert_eq!(messages(&assistant, cx), vec![(first, Role::User, 0..0)]);

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // An empty range only produces the second half of the split.
        let second = assistant
            .update(cx, |assistant, cx| assistant.split_message(3..3, cx))
            .1
            .unwrap()
            .id;

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![(first, Role::User, 0..4), (second, Role::User, 4..16)]
        );

        let third = assistant
            .update(cx, |assistant, cx| assistant.split_message(3..3, cx))
            .1
            .unwrap()
            .id;

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (first, Role::User, 0..4),
                (third, Role::User, 4..5),
                (second, Role::User, 5..17),
            ]
        );

        let fourth = assistant
            .update(cx, |assistant, cx| assistant.split_message(9..9, cx))
            .1
            .unwrap()
            .id;
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (first, Role::User, 0..4),
                (third, Role::User, 4..5),
                (second, Role::User, 5..9),
                (fourth, Role::User, 9..17),
            ]
        );

        let fifth = assistant
            .update(cx, |assistant, cx| assistant.split_message(9..9, cx))
            .1
            .unwrap()
            .id;
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (first, Role::User, 0..4),
                (third, Role::User, 4..5),
                (second, Role::User, 5..9),
                (fourth, Role::User, 9..10),
                (fifth, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range yields two messages: one starting at the
        // selected text and one for the text that followed it.
        let (selection, remainder) =
            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
        let sixth = selection.unwrap().id;
        let seventh = remainder.unwrap().id;
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (first, Role::User, 0..4),
                (third, Role::User, 4..5),
                (second, Role::User, 5..9),
                (fourth, Role::User, 9..10),
                (fifth, Role::User, 10..14),
                (sixth, Role::User, 14..17),
                (seventh, Role::User, 17..19),
            ]
        );
    }

    /// Inserts a message with `role` after the message `after` and returns the
    /// id of the newly inserted message, panicking if the insertion fails.
    fn insert_after(
        assistant: &ModelHandle<Assistant>,
        after: MessageId,
        role: Role,
        cx: &mut AppContext,
    ) -> MessageId {
        assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(after, role, cx)
                .unwrap()
                .id
        })
    }

    /// Flattens the assistant's messages into `(id, role, offset range)` tuples
    /// so whole conversations can be compared with a single `assert_eq!`.
    fn messages(
        assistant: &ModelHandle<Assistant>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        let assistant = assistant.read(cx);
        assistant
            .messages(cx)
            .map(|message| (message.id, message.role, message.range))
            .collect()
    }
}