assistant.rs

   1use crate::{
   2    assistant_settings::{AssistantDockPosition, AssistantSettings},
   3    OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role,
   4};
   5use anyhow::{anyhow, Result};
   6use chrono::{DateTime, Local};
   7use collections::{HashMap, HashSet};
   8use editor::{
   9    display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
  10    scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
  11    Anchor, Editor,
  12};
  13use fs::Fs;
  14use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  15use gpui::{
  16    actions,
  17    elements::*,
  18    executor::Background,
  19    geometry::vector::{vec2f, Vector2F},
  20    platform::{CursorStyle, MouseButton},
  21    Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
  22    Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
  23};
  24use isahc::{http::StatusCode, Request, RequestExt};
  25use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
  26use serde::Deserialize;
  27use settings::SettingsStore;
  28use std::{
  29    borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
  30    time::Duration,
  31};
  32use util::{channel::ReleaseChannel, post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
  33use workspace::{
  34    dock::{DockPosition, Panel},
  35    item::Item,
  36    pane, Pane, Workspace,
  37};
  38
  39const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
  40
  41actions!(
  42    assistant,
  43    [
  44        NewContext,
  45        Assist,
  46        Split,
  47        QuoteSelection,
  48        ToggleFocus,
  49        ResetKey
  50    ]
  51);
  52
  53pub fn init(cx: &mut AppContext) {
  54    if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
  55        cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
  56            filter.filtered_namespaces.insert("assistant");
  57        });
  58    }
  59
  60    settings::register::<AssistantSettings>(cx);
  61    cx.add_action(
  62        |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
  63            if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
  64                this.update(cx, |this, cx| this.add_context(cx))
  65            }
  66
  67            workspace.focus_panel::<AssistantPanel>(cx);
  68        },
  69    );
  70    cx.add_action(AssistantEditor::assist);
  71    cx.capture_action(AssistantEditor::cancel_last_assist);
  72    cx.add_action(AssistantEditor::quote_selection);
  73    cx.capture_action(AssistantEditor::copy);
  74    cx.capture_action(AssistantEditor::split);
  75    cx.add_action(AssistantPanel::save_api_key);
  76    cx.add_action(AssistantPanel::reset_api_key);
  77    cx.add_action(
  78        |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
  79            workspace.toggle_panel_focus::<AssistantPanel>(cx);
  80        },
  81    );
  82}
  83
  84pub enum AssistantPanelEvent {
  85    ZoomIn,
  86    ZoomOut,
  87    Focus,
  88    Close,
  89    DockPositionChanged,
  90}
  91
  92pub struct AssistantPanel {
  93    width: Option<f32>,
  94    height: Option<f32>,
  95    pane: ViewHandle<Pane>,
  96    api_key: Rc<RefCell<Option<String>>>,
  97    api_key_editor: Option<ViewHandle<Editor>>,
  98    has_read_credentials: bool,
  99    languages: Arc<LanguageRegistry>,
 100    fs: Arc<dyn Fs>,
 101    subscriptions: Vec<Subscription>,
 102}
 103
 104impl AssistantPanel {
 105    pub fn load(
 106        workspace: WeakViewHandle<Workspace>,
 107        cx: AsyncAppContext,
 108    ) -> Task<Result<ViewHandle<Self>>> {
 109        cx.spawn(|mut cx| async move {
 110            // TODO: deserialize state.
 111            workspace.update(&mut cx, |workspace, cx| {
 112                cx.add_view::<Self, _>(|cx| {
 113                    let weak_self = cx.weak_handle();
 114                    let pane = cx.add_view(|cx| {
 115                        let mut pane = Pane::new(
 116                            workspace.weak_handle(),
 117                            workspace.project().clone(),
 118                            workspace.app_state().background_actions,
 119                            Default::default(),
 120                            cx,
 121                        );
 122                        pane.set_can_split(false, cx);
 123                        pane.set_can_navigate(false, cx);
 124                        pane.on_can_drop(move |_, _| false);
 125                        pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
 126                            let weak_self = weak_self.clone();
 127                            Flex::row()
 128                                .with_child(Pane::render_tab_bar_button(
 129                                    0,
 130                                    "icons/plus_12.svg",
 131                                    false,
 132                                    Some(("New Context".into(), Some(Box::new(NewContext)))),
 133                                    cx,
 134                                    move |_, cx| {
 135                                        let weak_self = weak_self.clone();
 136                                        cx.window_context().defer(move |cx| {
 137                                            if let Some(this) = weak_self.upgrade(cx) {
 138                                                this.update(cx, |this, cx| this.add_context(cx));
 139                                            }
 140                                        })
 141                                    },
 142                                    None,
 143                                ))
 144                                .with_child(Pane::render_tab_bar_button(
 145                                    1,
 146                                    if pane.is_zoomed() {
 147                                        "icons/minimize_8.svg"
 148                                    } else {
 149                                        "icons/maximize_8.svg"
 150                                    },
 151                                    pane.is_zoomed(),
 152                                    Some((
 153                                        "Toggle Zoom".into(),
 154                                        Some(Box::new(workspace::ToggleZoom)),
 155                                    )),
 156                                    cx,
 157                                    move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
 158                                    None,
 159                                ))
 160                                .into_any()
 161                        });
 162                        let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
 163                        pane.toolbar()
 164                            .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
 165                        pane
 166                    });
 167
 168                    let mut this = Self {
 169                        pane,
 170                        api_key: Rc::new(RefCell::new(None)),
 171                        api_key_editor: None,
 172                        has_read_credentials: false,
 173                        languages: workspace.app_state().languages.clone(),
 174                        fs: workspace.app_state().fs.clone(),
 175                        width: None,
 176                        height: None,
 177                        subscriptions: Default::default(),
 178                    };
 179
 180                    let mut old_dock_position = this.position(cx);
 181                    this.subscriptions = vec![
 182                        cx.observe(&this.pane, |_, _, cx| cx.notify()),
 183                        cx.subscribe(&this.pane, Self::handle_pane_event),
 184                        cx.observe_global::<SettingsStore, _>(move |this, cx| {
 185                            let new_dock_position = this.position(cx);
 186                            if new_dock_position != old_dock_position {
 187                                old_dock_position = new_dock_position;
 188                                cx.emit(AssistantPanelEvent::DockPositionChanged);
 189                            }
 190                        }),
 191                    ];
 192
 193                    this
 194                })
 195            })
 196        })
 197    }
 198
 199    fn handle_pane_event(
 200        &mut self,
 201        _pane: ViewHandle<Pane>,
 202        event: &pane::Event,
 203        cx: &mut ViewContext<Self>,
 204    ) {
 205        match event {
 206            pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
 207            pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
 208            pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
 209            pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
 210            _ => {}
 211        }
 212    }
 213
 214    fn add_context(&mut self, cx: &mut ViewContext<Self>) {
 215        let focus = self.has_focus(cx);
 216        let editor = cx
 217            .add_view(|cx| AssistantEditor::new(self.api_key.clone(), self.languages.clone(), cx));
 218        self.subscriptions
 219            .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
 220        self.pane.update(cx, |pane, cx| {
 221            pane.add_item(Box::new(editor), true, focus, None, cx)
 222        });
 223    }
 224
 225    fn handle_assistant_editor_event(
 226        &mut self,
 227        _: ViewHandle<AssistantEditor>,
 228        event: &AssistantEditorEvent,
 229        cx: &mut ViewContext<Self>,
 230    ) {
 231        match event {
 232            AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
 233        }
 234    }
 235
 236    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
 237        if let Some(api_key) = self
 238            .api_key_editor
 239            .as_ref()
 240            .map(|editor| editor.read(cx).text(cx))
 241        {
 242            if !api_key.is_empty() {
 243                cx.platform()
 244                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
 245                    .log_err();
 246                *self.api_key.borrow_mut() = Some(api_key);
 247                self.api_key_editor.take();
 248                cx.focus_self();
 249                cx.notify();
 250            }
 251        } else {
 252            cx.propagate_action();
 253        }
 254    }
 255
 256    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
 257        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
 258        self.api_key.take();
 259        self.api_key_editor = Some(build_api_key_editor(cx));
 260        cx.focus_self();
 261        cx.notify();
 262    }
 263}
 264
 265fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
 266    cx.add_view(|cx| {
 267        let mut editor = Editor::single_line(
 268            Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
 269            cx,
 270        );
 271        editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
 272        editor
 273    })
 274}
 275
 276impl Entity for AssistantPanel {
 277    type Event = AssistantPanelEvent;
 278}
 279
 280impl View for AssistantPanel {
 281    fn ui_name() -> &'static str {
 282        "AssistantPanel"
 283    }
 284
 285    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
 286        let style = &theme::current(cx).assistant;
 287        if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 288            Flex::column()
 289                .with_child(
 290                    Text::new(
 291                        "Paste your OpenAI API key and press Enter to use the assistant",
 292                        style.api_key_prompt.text.clone(),
 293                    )
 294                    .aligned(),
 295                )
 296                .with_child(
 297                    ChildView::new(api_key_editor, cx)
 298                        .contained()
 299                        .with_style(style.api_key_editor.container)
 300                        .aligned(),
 301                )
 302                .contained()
 303                .with_style(style.api_key_prompt.container)
 304                .aligned()
 305                .into_any()
 306        } else {
 307            ChildView::new(&self.pane, cx).into_any()
 308        }
 309    }
 310
 311    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
 312        if cx.is_self_focused() {
 313            if let Some(api_key_editor) = self.api_key_editor.as_ref() {
 314                cx.focus(api_key_editor);
 315            } else {
 316                cx.focus(&self.pane);
 317            }
 318        }
 319    }
 320}
 321
 322impl Panel for AssistantPanel {
 323    fn position(&self, cx: &WindowContext) -> DockPosition {
 324        match settings::get::<AssistantSettings>(cx).dock {
 325            AssistantDockPosition::Left => DockPosition::Left,
 326            AssistantDockPosition::Bottom => DockPosition::Bottom,
 327            AssistantDockPosition::Right => DockPosition::Right,
 328        }
 329    }
 330
 331    fn position_is_valid(&self, _: DockPosition) -> bool {
 332        true
 333    }
 334
 335    fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
 336        settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
 337            let dock = match position {
 338                DockPosition::Left => AssistantDockPosition::Left,
 339                DockPosition::Bottom => AssistantDockPosition::Bottom,
 340                DockPosition::Right => AssistantDockPosition::Right,
 341            };
 342            settings.dock = Some(dock);
 343        });
 344    }
 345
 346    fn size(&self, cx: &WindowContext) -> f32 {
 347        let settings = settings::get::<AssistantSettings>(cx);
 348        match self.position(cx) {
 349            DockPosition::Left | DockPosition::Right => {
 350                self.width.unwrap_or_else(|| settings.default_width)
 351            }
 352            DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
 353        }
 354    }
 355
 356    fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
 357        match self.position(cx) {
 358            DockPosition::Left | DockPosition::Right => self.width = Some(size),
 359            DockPosition::Bottom => self.height = Some(size),
 360        }
 361        cx.notify();
 362    }
 363
 364    fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
 365        matches!(event, AssistantPanelEvent::ZoomIn)
 366    }
 367
 368    fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
 369        matches!(event, AssistantPanelEvent::ZoomOut)
 370    }
 371
 372    fn is_zoomed(&self, cx: &WindowContext) -> bool {
 373        self.pane.read(cx).is_zoomed()
 374    }
 375
 376    fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
 377        self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
 378    }
 379
 380    fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
 381        if active {
 382            if self.api_key.borrow().is_none() && !self.has_read_credentials {
 383                self.has_read_credentials = true;
 384                let api_key = if let Some((_, api_key)) = cx
 385                    .platform()
 386                    .read_credentials(OPENAI_API_URL)
 387                    .log_err()
 388                    .flatten()
 389                {
 390                    String::from_utf8(api_key).log_err()
 391                } else {
 392                    None
 393                };
 394                if let Some(api_key) = api_key {
 395                    *self.api_key.borrow_mut() = Some(api_key);
 396                } else if self.api_key_editor.is_none() {
 397                    self.api_key_editor = Some(build_api_key_editor(cx));
 398                    cx.notify();
 399                }
 400            }
 401
 402            if self.pane.read(cx).items_len() == 0 {
 403                self.add_context(cx);
 404            }
 405        }
 406    }
 407
 408    fn icon_path(&self) -> &'static str {
 409        "icons/robot_14.svg"
 410    }
 411
 412    fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
 413        ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
 414    }
 415
 416    fn should_change_position_on_event(event: &Self::Event) -> bool {
 417        matches!(event, AssistantPanelEvent::DockPositionChanged)
 418    }
 419
 420    fn should_activate_on_event(_: &Self::Event) -> bool {
 421        false
 422    }
 423
 424    fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
 425        matches!(event, AssistantPanelEvent::Close)
 426    }
 427
 428    fn has_focus(&self, cx: &WindowContext) -> bool {
 429        self.pane.read(cx).has_focus()
 430            || self
 431                .api_key_editor
 432                .as_ref()
 433                .map_or(false, |editor| editor.is_focused(cx))
 434    }
 435
 436    fn is_focus_event(event: &Self::Event) -> bool {
 437        matches!(event, AssistantPanelEvent::Focus)
 438    }
 439}
 440
 441enum AssistantEvent {
 442    MessagesEdited,
 443    SummaryChanged,
 444    StreamedCompletion,
 445}
 446
 447struct Assistant {
 448    buffer: ModelHandle<Buffer>,
 449    messages: Vec<Message>,
 450    messages_metadata: HashMap<MessageId, MessageMetadata>,
 451    next_message_id: MessageId,
 452    summary: Option<String>,
 453    pending_summary: Task<Option<()>>,
 454    completion_count: usize,
 455    pending_completions: Vec<PendingCompletion>,
 456    model: String,
 457    token_count: Option<usize>,
 458    max_token_count: usize,
 459    pending_token_count: Task<Option<()>>,
 460    api_key: Rc<RefCell<Option<String>>>,
 461    _subscriptions: Vec<Subscription>,
 462}
 463
 464impl Entity for Assistant {
 465    type Event = AssistantEvent;
 466}
 467
 468impl Assistant {
 469    fn new(
 470        api_key: Rc<RefCell<Option<String>>>,
 471        language_registry: Arc<LanguageRegistry>,
 472        cx: &mut ModelContext<Self>,
 473    ) -> Self {
 474        let model = "gpt-3.5-turbo";
 475        let markdown = language_registry.language_for_name("Markdown");
 476        let buffer = cx.add_model(|cx| {
 477            let mut buffer = Buffer::new(0, "", cx);
 478            buffer.set_language_registry(language_registry);
 479            cx.spawn_weak(|buffer, mut cx| async move {
 480                let markdown = markdown.await?;
 481                let buffer = buffer
 482                    .upgrade(&cx)
 483                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
 484                buffer.update(&mut cx, |buffer, cx| {
 485                    buffer.set_language(Some(markdown), cx)
 486                });
 487                anyhow::Ok(())
 488            })
 489            .detach_and_log_err(cx);
 490            buffer
 491        });
 492
 493        let mut this = Self {
 494            messages: Default::default(),
 495            messages_metadata: Default::default(),
 496            next_message_id: Default::default(),
 497            summary: None,
 498            pending_summary: Task::ready(None),
 499            completion_count: Default::default(),
 500            pending_completions: Default::default(),
 501            token_count: None,
 502            max_token_count: tiktoken_rs::model::get_context_size(model),
 503            pending_token_count: Task::ready(None),
 504            model: model.into(),
 505            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
 506            api_key,
 507            buffer,
 508        };
 509        let message = Message {
 510            id: MessageId(post_inc(&mut this.next_message_id.0)),
 511            start: language::Anchor::MIN,
 512        };
 513        this.messages.push(message.clone());
 514        this.messages_metadata.insert(
 515            message.id,
 516            MessageMetadata {
 517                role: Role::User,
 518                sent_at: Local::now(),
 519                error: None,
 520            },
 521        );
 522
 523        this.count_remaining_tokens(cx);
 524        this
 525    }
 526
 527    fn handle_buffer_event(
 528        &mut self,
 529        _: ModelHandle<Buffer>,
 530        event: &language::Event,
 531        cx: &mut ModelContext<Self>,
 532    ) {
 533        match event {
 534            language::Event::Edited => {
 535                self.count_remaining_tokens(cx);
 536                cx.emit(AssistantEvent::MessagesEdited);
 537            }
 538            _ => {}
 539        }
 540    }
 541
 542    fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
 543        let messages = self
 544            .open_ai_request_messages(cx)
 545            .into_iter()
 546            .filter_map(|message| {
 547                Some(tiktoken_rs::ChatCompletionRequestMessage {
 548                    role: match message.role {
 549                        Role::User => "user".into(),
 550                        Role::Assistant => "assistant".into(),
 551                        Role::System => "system".into(),
 552                    },
 553                    content: message.content,
 554                    name: None,
 555                })
 556            })
 557            .collect::<Vec<_>>();
 558        let model = self.model.clone();
 559        self.pending_token_count = cx.spawn_weak(|this, mut cx| {
 560            async move {
 561                cx.background().timer(Duration::from_millis(200)).await;
 562                let token_count = cx
 563                    .background()
 564                    .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
 565                    .await?;
 566
 567                this.upgrade(&cx)
 568                    .ok_or_else(|| anyhow!("assistant was dropped"))?
 569                    .update(&mut cx, |this, cx| {
 570                        this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
 571                        this.token_count = Some(token_count);
 572                        cx.notify()
 573                    });
 574                anyhow::Ok(())
 575            }
 576            .log_err()
 577        });
 578    }
 579
 580    fn remaining_tokens(&self) -> Option<isize> {
 581        Some(self.max_token_count as isize - self.token_count? as isize)
 582    }
 583
 584    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
 585        self.model = model;
 586        self.count_remaining_tokens(cx);
 587        cx.notify();
 588    }
 589
 590    fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
 591        let request = OpenAIRequest {
 592            model: self.model.clone(),
 593            messages: self.open_ai_request_messages(cx),
 594            stream: true,
 595        };
 596
 597        let api_key = self.api_key.borrow().clone()?;
 598        let stream = stream_completion(api_key, cx.background().clone(), request);
 599        let assistant_message =
 600            self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
 601        let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
 602        let task = cx.spawn_weak({
 603            |this, mut cx| async move {
 604                let assistant_message_id = assistant_message.id;
 605                let stream_completion = async {
 606                    let mut messages = stream.await?;
 607
 608                    while let Some(message) = messages.next().await {
 609                        let mut message = message?;
 610                        if let Some(choice) = message.choices.pop() {
 611                            this.upgrade(&cx)
 612                                .ok_or_else(|| anyhow!("assistant was dropped"))?
 613                                .update(&mut cx, |this, cx| {
 614                                    let text: Arc<str> = choice.delta.content?.into();
 615                                    let message_ix = this
 616                                        .messages
 617                                        .iter()
 618                                        .position(|message| message.id == assistant_message_id)?;
 619                                    this.buffer.update(cx, |buffer, cx| {
 620                                        let offset = if message_ix + 1 == this.messages.len() {
 621                                            buffer.len()
 622                                        } else {
 623                                            this.messages[message_ix + 1]
 624                                                .start
 625                                                .to_offset(buffer)
 626                                                .saturating_sub(1)
 627                                        };
 628                                        buffer.edit([(offset..offset, text)], None, cx);
 629                                    });
 630                                    cx.emit(AssistantEvent::StreamedCompletion);
 631
 632                                    Some(())
 633                                });
 634                        }
 635                    }
 636
 637                    this.upgrade(&cx)
 638                        .ok_or_else(|| anyhow!("assistant was dropped"))?
 639                        .update(&mut cx, |this, cx| {
 640                            this.pending_completions
 641                                .retain(|completion| completion.id != this.completion_count);
 642                            this.summarize(cx);
 643                        });
 644
 645                    anyhow::Ok(())
 646                };
 647
 648                let result = stream_completion.await;
 649                if let Some(this) = this.upgrade(&cx) {
 650                    this.update(&mut cx, |this, cx| {
 651                        if let Err(error) = result {
 652                            if let Some(metadata) =
 653                                this.messages_metadata.get_mut(&assistant_message.id)
 654                            {
 655                                metadata.error = Some(error.to_string().trim().into());
 656                                cx.notify();
 657                            }
 658                        }
 659                    });
 660                }
 661            }
 662        });
 663
 664        self.pending_completions.push(PendingCompletion {
 665            id: post_inc(&mut self.completion_count),
 666            _task: task,
 667        });
 668        Some((assistant_message, user_message))
 669    }
 670
 671    fn cancel_last_assist(&mut self) -> bool {
 672        self.pending_completions.pop().is_some()
 673    }
 674
 675    fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext<Self>) {
 676        if let Some(metadata) = self.messages_metadata.get_mut(&id) {
 677            metadata.role.cycle();
 678            cx.emit(AssistantEvent::MessagesEdited);
 679            cx.notify();
 680        }
 681    }
 682
 683    fn insert_message_after(
 684        &mut self,
 685        message_id: MessageId,
 686        role: Role,
 687        cx: &mut ModelContext<Self>,
 688    ) -> Option<Message> {
 689        if let Some(prev_message_ix) = self
 690            .messages
 691            .iter()
 692            .position(|message| message.id == message_id)
 693        {
 694            let start = self.buffer.update(cx, |buffer, cx| {
 695                let offset = self.messages[prev_message_ix + 1..]
 696                    .iter()
 697                    .find(|message| message.start.is_valid(buffer))
 698                    .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
 699                buffer.edit([(offset..offset, "\n")], None, cx);
 700                buffer.anchor_before(offset + 1)
 701            });
 702            let message = Message {
 703                id: MessageId(post_inc(&mut self.next_message_id.0)),
 704                start,
 705            };
 706            self.messages.insert(prev_message_ix + 1, message.clone());
 707            self.messages_metadata.insert(
 708                message.id,
 709                MessageMetadata {
 710                    role,
 711                    sent_at: Local::now(),
 712                    error: None,
 713                },
 714            );
 715            cx.emit(AssistantEvent::MessagesEdited);
 716            Some(message)
 717        } else {
 718            None
 719        }
 720    }
 721
 722    fn split_message(
 723        &mut self,
 724        range: Range<usize>,
 725        cx: &mut ModelContext<Self>,
 726    ) -> (Option<Message>, Option<Message>) {
 727        let start_message = self.message_for_offset(range.start, cx);
 728        let end_message = self.message_for_offset(range.end, cx);
 729        if let Some((start_message, end_message)) = start_message.zip(end_message) {
 730            let (start_message_ix, _, metadata, message_range) = start_message;
 731            let (end_message_ix, _, _, _) = end_message;
 732
 733            // Prevent splitting when range spans multiple messages.
 734            if start_message_ix != end_message_ix {
 735                return (None, None);
 736            }
 737
 738            let role = metadata.role;
 739            let mut edited_buffer = false;
 740
 741            let mut suffix_start = None;
 742            if range.start > message_range.start && range.end < message_range.end - 1 {
 743                if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
 744                    suffix_start = Some(range.end + 1);
 745                } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
 746                    suffix_start = Some(range.end);
 747                }
 748            }
 749
 750            let suffix = if let Some(suffix_start) = suffix_start {
 751                Message {
 752                    id: MessageId(post_inc(&mut self.next_message_id.0)),
 753                    start: self.buffer.read(cx).anchor_before(suffix_start),
 754                }
 755            } else {
 756                self.buffer.update(cx, |buffer, cx| {
 757                    buffer.edit([(range.end..range.end, "\n")], None, cx);
 758                });
 759                edited_buffer = true;
 760                Message {
 761                    id: MessageId(post_inc(&mut self.next_message_id.0)),
 762                    start: self.buffer.read(cx).anchor_before(range.end + 1),
 763                }
 764            };
 765
 766            self.messages.insert(start_message_ix + 1, suffix.clone());
 767            self.messages_metadata.insert(
 768                suffix.id,
 769                MessageMetadata {
 770                    role,
 771                    sent_at: Local::now(),
 772                    error: None,
 773                },
 774            );
 775
 776            let new_messages = if range.start == range.end || range.start == message_range.start {
 777                (None, Some(suffix))
 778            } else {
 779                let mut prefix_end = None;
 780                if range.start > message_range.start && range.end < message_range.end - 1 {
 781                    if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
 782                        prefix_end = Some(range.start + 1);
 783                    } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
 784                        == Some('\n')
 785                    {
 786                        prefix_end = Some(range.start);
 787                    }
 788                }
 789
 790                let selection = if let Some(prefix_end) = prefix_end {
 791                    cx.emit(AssistantEvent::MessagesEdited);
 792                    Message {
 793                        id: MessageId(post_inc(&mut self.next_message_id.0)),
 794                        start: self.buffer.read(cx).anchor_before(prefix_end),
 795                    }
 796                } else {
 797                    self.buffer.update(cx, |buffer, cx| {
 798                        buffer.edit([(range.start..range.start, "\n")], None, cx)
 799                    });
 800                    edited_buffer = true;
 801                    Message {
 802                        id: MessageId(post_inc(&mut self.next_message_id.0)),
 803                        start: self.buffer.read(cx).anchor_before(range.end + 1),
 804                    }
 805                };
 806
 807                self.messages
 808                    .insert(start_message_ix + 1, selection.clone());
 809                self.messages_metadata.insert(
 810                    selection.id,
 811                    MessageMetadata {
 812                        role,
 813                        sent_at: Local::now(),
 814                        error: None,
 815                    },
 816                );
 817                (Some(selection), Some(suffix))
 818            };
 819
 820            if !edited_buffer {
 821                cx.emit(AssistantEvent::MessagesEdited);
 822            }
 823            new_messages
 824        } else {
 825            (None, None)
 826        }
 827    }
 828
 829    fn summarize(&mut self, cx: &mut ModelContext<Self>) {
 830        if self.messages.len() >= 2 && self.summary.is_none() {
 831            let api_key = self.api_key.borrow().clone();
 832            if let Some(api_key) = api_key {
 833                let mut messages = self.open_ai_request_messages(cx);
 834                messages.truncate(2);
 835                messages.push(RequestMessage {
 836                    role: Role::User,
 837                    content: "Summarize the conversation into a short title without punctuation"
 838                        .into(),
 839                });
 840                let request = OpenAIRequest {
 841                    model: self.model.clone(),
 842                    messages,
 843                    stream: true,
 844                };
 845
 846                let stream = stream_completion(api_key, cx.background().clone(), request);
 847                self.pending_summary = cx.spawn(|this, mut cx| {
 848                    async move {
 849                        let mut messages = stream.await?;
 850
 851                        while let Some(message) = messages.next().await {
 852                            let mut message = message?;
 853                            if let Some(choice) = message.choices.pop() {
 854                                let text = choice.delta.content.unwrap_or_default();
 855                                this.update(&mut cx, |this, cx| {
 856                                    this.summary.get_or_insert(String::new()).push_str(&text);
 857                                    cx.emit(AssistantEvent::SummaryChanged);
 858                                });
 859                            }
 860                        }
 861
 862                        anyhow::Ok(())
 863                    }
 864                    .log_err()
 865                });
 866            }
 867        }
 868    }
 869
 870    fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
 871        let buffer = self.buffer.read(cx);
 872        self.messages(cx)
 873            .map(|(_ix, _message, metadata, range)| RequestMessage {
 874                role: metadata.role,
 875                content: buffer.text_for_range(range).collect(),
 876            })
 877            .collect()
 878    }
 879
 880    fn message_for_offset<'a>(
 881        &'a self,
 882        offset: usize,
 883        cx: &'a AppContext,
 884    ) -> Option<(usize, &Message, &MessageMetadata, Range<usize>)> {
 885        let mut messages = self.messages(cx).peekable();
 886        while let Some((ix, message, metadata, range)) = messages.next() {
 887            if range.contains(&offset) || messages.peek().is_none() {
 888                return Some((ix, message, metadata, range));
 889            }
 890        }
 891        None
 892    }
 893
    /// Iterates over the conversation's messages in order, yielding
    /// `(index, message, metadata, range)`. `index` is the position in
    /// `self.messages` and `range` is the message's byte range in the buffer.
    ///
    /// A message's range runs from its start anchor to the start of the next
    /// message whose anchor is still valid in the buffer. For the final
    /// message it runs to the end of the buffer.
    ///
    /// Following messages with invalid anchors are consumed and never yielded.
    /// Their text is therefore attributed to the preceding message, and their
    /// indices are absent from the output. Anchor validity is only checked for
    /// "next" messages; the first entry of `self.messages` is never checked.
    ///
    /// Iteration ends early, without yielding further messages, if a message
    /// has no entry in `messages_metadata`.
    fn messages<'a>(
        &'a self,
        cx: &'a AppContext,
    ) -> impl 'a + Iterator<Item = (usize, &Message, &MessageMetadata, Range<usize>)> {
        let buffer = self.buffer.read(cx);
        let mut messages = self.messages.iter().enumerate().peekable();
        iter::from_fn(move || {
            // This outer `while let` never loops: its body always returns, so it acts as `if let`.
            while let Some((ix, message)) = messages.next() {
                // `?` here ends the whole iterator (`from_fn` stops on `None`), not just this item.
                let metadata = self.messages_metadata.get(&message.id)?;
                let message_start = message.start.to_offset(buffer);
                let mut message_end = None;
                // Find the next message with a still-valid start anchor; consume (skip) invalid ones.
                while let Some((_, next_message)) = messages.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        messages.next();
                    }
                }
                // No valid successor: this message extends to the end of the buffer.
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);
                return Some((ix, message, metadata, message_start..message_end));
            }
            None
        })
    }
 921}
 922
/// Bookkeeping for an in-flight completion request.
struct PendingCompletion {
    /// Identifier for this completion. NOTE(review): presumably used to match
    /// or cancel a specific completion; confirm against its users.
    id: usize,
    /// Held only to keep the spawned task alive; the leading underscore marks
    /// it as never read. NOTE(review): assumes dropping a gpui `Task` cancels
    /// it. Confirm.
    _task: Task<()>,
}
 927
/// Events emitted by `AssistantEditor` to its subscribers.
enum AssistantEditorEvent {
    /// The tab's label/tooltip may need re-rendering. Emitted when the
    /// conversation summary (used as the title) changes.
    TabContentChanged,
}
 931
/// A cursor position pinned relative to the viewport. `AssistantEditor` uses it
/// to keep the cursor on the same on-screen row while a completion streams into
/// the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x` is the horizontal scroll position. `y` is the number of display rows
    /// between the top of the viewport and the cursor's row.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection when this position was captured.
    cursor: Anchor,
}
 937
/// View for a single assistant conversation. It wraps an `Editor` over the
/// `Assistant`'s buffer, decorated with one header block per message.
struct AssistantEditor {
    /// The conversation model backing this view.
    assistant: ModelHandle<Assistant>,
    /// Editor displaying (and editing) the assistant's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Where the cursor should stay on screen while completions stream in.
    /// `None` when the cursor is not pinned.
    scroll_position: Option<ScrollPosition>,
    /// Model/editor subscriptions, held so they stay active.
    _subscriptions: Vec<Subscription>,
}
 945
 946impl AssistantEditor {
 947    fn new(
 948        api_key: Rc<RefCell<Option<String>>>,
 949        language_registry: Arc<LanguageRegistry>,
 950        cx: &mut ViewContext<Self>,
 951    ) -> Self {
 952        let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
 953        let editor = cx.add_view(|cx| {
 954            let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
 955            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
 956            editor.set_show_gutter(false, cx);
 957            editor
 958        });
 959
 960        let _subscriptions = vec![
 961            cx.observe(&assistant, |_, _, cx| cx.notify()),
 962            cx.subscribe(&assistant, Self::handle_assistant_event),
 963            cx.subscribe(&editor, Self::handle_editor_event),
 964        ];
 965
 966        let mut this = Self {
 967            assistant,
 968            editor,
 969            blocks: Default::default(),
 970            scroll_position: None,
 971            _subscriptions,
 972        };
 973        this.update_message_headers(cx);
 974        this
 975    }
 976
 977    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
 978        let user_message = self.assistant.update(cx, |assistant, cx| {
 979            let (_, user_message) = assistant.assist(cx)?;
 980            Some(user_message)
 981        });
 982
 983        if let Some(user_message) = user_message {
 984            let cursor = user_message
 985                .start
 986                .to_offset(&self.assistant.read(cx).buffer.read(cx));
 987            self.editor.update(cx, |editor, cx| {
 988                editor.change_selections(
 989                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
 990                    cx,
 991                    |selections| selections.select_ranges([cursor..cursor]),
 992                );
 993            });
 994        }
 995    }
 996
 997    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
 998        if !self
 999            .assistant
1000            .update(cx, |assistant, _| assistant.cancel_last_assist())
1001        {
1002            cx.propagate_action();
1003        }
1004    }
1005
1006    fn handle_assistant_event(
1007        &mut self,
1008        _: ModelHandle<Assistant>,
1009        event: &AssistantEvent,
1010        cx: &mut ViewContext<Self>,
1011    ) {
1012        match event {
1013            AssistantEvent::MessagesEdited => self.update_message_headers(cx),
1014            AssistantEvent::SummaryChanged => {
1015                cx.emit(AssistantEditorEvent::TabContentChanged);
1016            }
1017            AssistantEvent::StreamedCompletion => {
1018                self.editor.update(cx, |editor, cx| {
1019                    if let Some(scroll_position) = self.scroll_position {
1020                        let snapshot = editor.snapshot(cx);
1021                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1022                        let scroll_top =
1023                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1024                        editor.set_scroll_position(
1025                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1026                            cx,
1027                        );
1028                    }
1029                });
1030            }
1031        }
1032    }
1033
1034    fn handle_editor_event(
1035        &mut self,
1036        _: ViewHandle<Editor>,
1037        event: &editor::Event,
1038        cx: &mut ViewContext<Self>,
1039    ) {
1040        match event {
1041            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1042                let cursor_scroll_position = self.cursor_scroll_position(cx);
1043                if *autoscroll {
1044                    self.scroll_position = cursor_scroll_position;
1045                } else if self.scroll_position != cursor_scroll_position {
1046                    self.scroll_position = None;
1047                }
1048            }
1049            editor::Event::SelectionsChanged { .. } => {
1050                self.scroll_position = self.cursor_scroll_position(cx);
1051            }
1052            _ => {}
1053        }
1054    }
1055
1056    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1057        self.editor.update(cx, |editor, cx| {
1058            let snapshot = editor.snapshot(cx);
1059            let cursor = editor.selections.newest_anchor().head();
1060            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1061            let scroll_position = editor
1062                .scroll_manager
1063                .anchor()
1064                .scroll_position(&snapshot.display_snapshot);
1065
1066            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1067            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1068                Some(ScrollPosition {
1069                    cursor,
1070                    offset_before_cursor: vec2f(
1071                        scroll_position.x(),
1072                        cursor_row - scroll_position.y(),
1073                    ),
1074                })
1075            } else {
1076                None
1077            }
1078        })
1079    }
1080
    /// Rebuilds the header block shown above every message in the editor.
    ///
    /// All previously inserted header blocks are removed. One new sticky block,
    /// two lines tall, is inserted above each message yielded by
    /// `Assistant::messages`. Each header shows:
    /// - the sender's role; clicking it calls `Assistant::cycle_message_role`;
    /// - the time the message was sent;
    /// - an error icon, whose tooltip is the error text, when the message's
    ///   metadata carries an error.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // The editor is built over a single buffer (see `new`), so the multibuffer is
            // expected to be a singleton; this `unwrap` panics otherwise.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            // Take the current header ids; they are all removed below.
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .assistant
                .read(cx)
                .messages(cx)
                .map(|(_, message, metadata, _)| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.start),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        // Owned copies captured by the render closure.
                        let assistant = self.assistant.clone();
                        let metadata = metadata.clone();
                        let message = message.clone();
                        move |cx| {
                            // Tag types giving this header's mouse region and tooltip their own identity.
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match metadata.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state, false);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state, false);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state, false);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            // Pressing the sender label cycles this message's role.
                            .on_down(MouseButton::Left, {
                                let assistant = assistant.clone();
                                move |_, _, cx| {
                                    assistant.update(cx, |assistant, cx| {
                                        assistant.cycle_message_role(message_id, cx)
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                // Sent time on a 12-hour clock with lowercase am/pm, e.g. "03:07pm".
                                .with_child(
                                    Label::new(
                                        metadata.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                // Error icon (tooltip = error text), only when the message has an error.
                                .with_children(metadata.error.clone().map(|error| {
                                    Svg::new("icons/circle_x_mark_12.svg")
                                        .with_color(style.error_icon.color)
                                        .constrained()
                                        .with_width(style.error_icon.width)
                                        .contained()
                                        .with_style(style.error_icon.container)
                                        .with_tooltip::<ErrorTooltip>(
                                            message_id.0,
                                            error,
                                            None,
                                            theme.tooltip.clone(),
                                            cx,
                                        )
                                        .aligned()
                                }))
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            // Replace the old headers wholesale and remember the new block ids.
            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1182
1183    fn quote_selection(
1184        workspace: &mut Workspace,
1185        _: &QuoteSelection,
1186        cx: &mut ViewContext<Workspace>,
1187    ) {
1188        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1189            return;
1190        };
1191        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1192            return;
1193        };
1194
1195        let text = editor.read_with(cx, |editor, cx| {
1196            let range = editor.selections.newest::<usize>(cx).range();
1197            let buffer = editor.buffer().read(cx).snapshot(cx);
1198            let start_language = buffer.language_at(range.start);
1199            let end_language = buffer.language_at(range.end);
1200            let language_name = if start_language == end_language {
1201                start_language.map(|language| language.name())
1202            } else {
1203                None
1204            };
1205            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1206
1207            let selected_text = buffer.text_for_range(range).collect::<String>();
1208            if selected_text.is_empty() {
1209                None
1210            } else {
1211                Some(if language_name == "markdown" {
1212                    selected_text
1213                        .lines()
1214                        .map(|line| format!("> {}", line))
1215                        .collect::<Vec<_>>()
1216                        .join("\n")
1217                } else {
1218                    format!("```{language_name}\n{selected_text}\n```")
1219                })
1220            }
1221        });
1222
1223        // Activate the panel
1224        if !panel.read(cx).has_focus(cx) {
1225            workspace.toggle_panel_focus::<AssistantPanel>(cx);
1226        }
1227
1228        if let Some(text) = text {
1229            panel.update(cx, |panel, cx| {
1230                if let Some(assistant) = panel
1231                    .pane
1232                    .read(cx)
1233                    .active_item()
1234                    .and_then(|item| item.downcast::<AssistantEditor>())
1235                    .ok_or_else(|| anyhow!("no active context"))
1236                    .log_err()
1237                {
1238                    assistant.update(cx, |assistant, cx| {
1239                        assistant
1240                            .editor
1241                            .update(cx, |editor, cx| editor.insert(&text, cx))
1242                    });
1243                }
1244            });
1245        }
1246    }
1247
1248    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1249        let editor = self.editor.read(cx);
1250        let assistant = self.assistant.read(cx);
1251        if editor.selections.count() == 1 {
1252            let selection = editor.selections.newest::<usize>(cx);
1253            let mut copied_text = String::new();
1254            let mut spanned_messages = 0;
1255            for (_ix, _message, metadata, message_range) in assistant.messages(cx) {
1256                if message_range.start >= selection.range().end {
1257                    break;
1258                } else if message_range.end >= selection.range().start {
1259                    let range = cmp::max(message_range.start, selection.range().start)
1260                        ..cmp::min(message_range.end, selection.range().end);
1261                    if !range.is_empty() {
1262                        spanned_messages += 1;
1263                        write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
1264                        for chunk in assistant.buffer.read(cx).text_for_range(range) {
1265                            copied_text.push_str(&chunk);
1266                        }
1267                        copied_text.push('\n');
1268                    }
1269                }
1270            }
1271
1272            if spanned_messages > 1 {
1273                cx.platform()
1274                    .write_to_clipboard(ClipboardItem::new(copied_text));
1275                return;
1276            }
1277        }
1278
1279        cx.propagate_action();
1280    }
1281
1282    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1283        self.assistant.update(cx, |assistant, cx| {
1284            let range = self.editor.read(cx).selections.newest::<usize>(cx).range();
1285            assistant.split_message(range, cx);
1286        });
1287    }
1288
1289    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1290        self.assistant.update(cx, |assistant, cx| {
1291            let new_model = match assistant.model.as_str() {
1292                "gpt-4" => "gpt-3.5-turbo",
1293                _ => "gpt-4",
1294            };
1295            assistant.set_model(new_model.into(), cx);
1296        });
1297    }
1298
1299    fn title(&self, cx: &AppContext) -> String {
1300        self.assistant
1301            .read(cx)
1302            .summary
1303            .clone()
1304            .unwrap_or_else(|| "New Context".into())
1305    }
1306}
1307
/// `AssistantEditor` emits `AssistantEditorEvent`s to its subscribers.
impl Entity for AssistantEditor {
    type Event = AssistantEditorEvent;
}
1311
1312impl View for AssistantEditor {
1313    fn ui_name() -> &'static str {
1314        "AssistantEditor"
1315    }
1316
1317    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
1318        enum Model {}
1319        let theme = &theme::current(cx).assistant;
1320        let assistant = &self.assistant.read(cx);
1321        let model = assistant.model.clone();
1322        let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
1323            let remaining_tokens_style = if remaining_tokens <= 0 {
1324                &theme.no_remaining_tokens
1325            } else {
1326                &theme.remaining_tokens
1327            };
1328            Label::new(
1329                remaining_tokens.to_string(),
1330                remaining_tokens_style.text.clone(),
1331            )
1332            .contained()
1333            .with_style(remaining_tokens_style.container)
1334        });
1335
1336        Stack::new()
1337            .with_child(
1338                ChildView::new(&self.editor, cx)
1339                    .contained()
1340                    .with_style(theme.container),
1341            )
1342            .with_child(
1343                Flex::row()
1344                    .with_child(
1345                        MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
1346                            let style = theme.model.style_for(state, false);
1347                            Label::new(model, style.text.clone())
1348                                .contained()
1349                                .with_style(style.container)
1350                        })
1351                        .with_cursor_style(CursorStyle::PointingHand)
1352                        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
1353                    )
1354                    .with_children(remaining_tokens)
1355                    .contained()
1356                    .with_style(theme.model_info_container)
1357                    .aligned()
1358                    .top()
1359                    .right(),
1360            )
1361            .into_any()
1362    }
1363
1364    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
1365        if cx.is_self_focused() {
1366            cx.focus(&self.editor);
1367        }
1368    }
1369}
1370
1371impl Item for AssistantEditor {
1372    fn tab_content<V: View>(
1373        &self,
1374        _: Option<usize>,
1375        style: &theme::Tab,
1376        cx: &gpui::AppContext,
1377    ) -> AnyElement<V> {
1378        let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1379        Label::new(title, style.label.clone()).into_any()
1380    }
1381
1382    fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1383        Some(self.title(cx).into())
1384    }
1385
1386    fn as_searchable(
1387        &self,
1388        _: &ViewHandle<Self>,
1389    ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1390        Some(Box::new(self.editor.clone()))
1391    }
1392}
1393
/// Identifies a message within an `Assistant`. New ids are drawn from the
/// assistant's counter via `post_inc(&mut self.next_message_id.0)`.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1396
/// A message in the conversation, identified by id and located by where its
/// text starts in the assistant's buffer.
#[derive(Clone, Debug)]
struct Message {
    /// Key into the assistant's `messages_metadata`.
    id: MessageId,
    /// Buffer anchor where this message's text begins. The message ends where
    /// the next message with a still-valid anchor begins, or at the end of the
    /// buffer.
    start: language::Anchor,
}
1402
/// Per-message data kept outside the buffer, keyed by `MessageId`.
#[derive(Clone, Debug)]
struct MessageMetadata {
    /// Author of the message. Sent as the OpenAI role and cycled by clicking
    /// the message header.
    role: Role,
    /// When the message was created (set to `Local::now()` on insertion).
    /// Displayed in the message header.
    sent_at: DateTime<Local>,
    /// Error to surface for this message. When set, the header shows an error
    /// icon whose tooltip is this text.
    error: Option<String>,
}
1409
1410async fn stream_completion(
1411    api_key: String,
1412    executor: Arc<Background>,
1413    mut request: OpenAIRequest,
1414) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1415    request.stream = true;
1416
1417    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1418
1419    let json_data = serde_json::to_string(&request)?;
1420    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1421        .header("Content-Type", "application/json")
1422        .header("Authorization", format!("Bearer {}", api_key))
1423        .body(json_data)?
1424        .send_async()
1425        .await?;
1426
1427    let status = response.status();
1428    if status == StatusCode::OK {
1429        executor
1430            .spawn(async move {
1431                let mut lines = BufReader::new(response.body_mut()).lines();
1432
1433                fn parse_line(
1434                    line: Result<String, io::Error>,
1435                ) -> Result<Option<OpenAIResponseStreamEvent>> {
1436                    if let Some(data) = line?.strip_prefix("data: ") {
1437                        let event = serde_json::from_str(&data)?;
1438                        Ok(Some(event))
1439                    } else {
1440                        Ok(None)
1441                    }
1442                }
1443
1444                while let Some(line) = lines.next().await {
1445                    if let Some(event) = parse_line(line).transpose() {
1446                        let done = event.as_ref().map_or(false, |event| {
1447                            event
1448                                .choices
1449                                .last()
1450                                .map_or(false, |choice| choice.finish_reason.is_some())
1451                        });
1452                        if tx.unbounded_send(event).is_err() {
1453                            break;
1454                        }
1455
1456                        if done {
1457                            break;
1458                        }
1459                    }
1460                }
1461
1462                anyhow::Ok(())
1463            })
1464            .detach();
1465
1466        Ok(rx)
1467    } else {
1468        let mut body = String::new();
1469        response.body_mut().read_to_string(&mut body).await?;
1470
1471        #[derive(Deserialize)]
1472        struct OpenAIResponse {
1473            error: OpenAIError,
1474        }
1475
1476        #[derive(Deserialize)]
1477        struct OpenAIError {
1478            message: String,
1479        }
1480
1481        match serde_json::from_str::<OpenAIResponse>(&body) {
1482            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1483                "Failed to connect to OpenAI API: {}",
1484                response.error.message,
1485            )),
1486
1487            _ => Err(anyhow!(
1488                "Failed to connect to OpenAI API: {} {}",
1489                response.status(),
1490                body,
1491            )),
1492        }
1493    }
1494}
1495
1496#[cfg(test)]
1497mod tests {
1498    use super::*;
1499    use gpui::AppContext;
1500
1501    #[gpui::test]
1502    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
1503        let registry = Arc::new(LanguageRegistry::test());
1504        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
1505        let buffer = assistant.read(cx).buffer.clone();
1506
1507        let message_1 = assistant.read(cx).messages[0].clone();
1508        assert_eq!(
1509            messages(&assistant, cx),
1510            vec![(message_1.id, Role::User, 0..0)]
1511        );
1512
1513        let message_2 = assistant.update(cx, |assistant, cx| {
1514            assistant
1515                .insert_message_after(message_1.id, Role::Assistant, cx)
1516                .unwrap()
1517        });
1518        assert_eq!(
1519            messages(&assistant, cx),
1520            vec![
1521                (message_1.id, Role::User, 0..1),
1522                (message_2.id, Role::Assistant, 1..1)
1523            ]
1524        );
1525
1526        buffer.update(cx, |buffer, cx| {
1527            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
1528        });
1529        assert_eq!(
1530            messages(&assistant, cx),
1531            vec![
1532                (message_1.id, Role::User, 0..2),
1533                (message_2.id, Role::Assistant, 2..3)
1534            ]
1535        );
1536
1537        let message_3 = assistant.update(cx, |assistant, cx| {
1538            assistant
1539                .insert_message_after(message_2.id, Role::User, cx)
1540                .unwrap()
1541        });
1542        assert_eq!(
1543            messages(&assistant, cx),
1544            vec![
1545                (message_1.id, Role::User, 0..2),
1546                (message_2.id, Role::Assistant, 2..4),
1547                (message_3.id, Role::User, 4..4)
1548            ]
1549        );
1550
1551        let message_4 = assistant.update(cx, |assistant, cx| {
1552            assistant
1553                .insert_message_after(message_2.id, Role::User, cx)
1554                .unwrap()
1555        });
1556        assert_eq!(
1557            messages(&assistant, cx),
1558            vec![
1559                (message_1.id, Role::User, 0..2),
1560                (message_2.id, Role::Assistant, 2..4),
1561                (message_4.id, Role::User, 4..5),
1562                (message_3.id, Role::User, 5..5),
1563            ]
1564        );
1565
1566        buffer.update(cx, |buffer, cx| {
1567            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
1568        });
1569        assert_eq!(
1570            messages(&assistant, cx),
1571            vec![
1572                (message_1.id, Role::User, 0..2),
1573                (message_2.id, Role::Assistant, 2..4),
1574                (message_4.id, Role::User, 4..6),
1575                (message_3.id, Role::User, 6..7),
1576            ]
1577        );
1578
1579        // Deleting across message boundaries merges the messages.
1580        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
1581        assert_eq!(
1582            messages(&assistant, cx),
1583            vec![
1584                (message_1.id, Role::User, 0..3),
1585                (message_3.id, Role::User, 3..4),
1586            ]
1587        );
1588
1589        // Undoing the deletion should also undo the merge.
1590        buffer.update(cx, |buffer, cx| buffer.undo(cx));
1591        assert_eq!(
1592            messages(&assistant, cx),
1593            vec![
1594                (message_1.id, Role::User, 0..2),
1595                (message_2.id, Role::Assistant, 2..4),
1596                (message_4.id, Role::User, 4..6),
1597                (message_3.id, Role::User, 6..7),
1598            ]
1599        );
1600
1601        // Redoing the deletion should also redo the merge.
1602        buffer.update(cx, |buffer, cx| buffer.redo(cx));
1603        assert_eq!(
1604            messages(&assistant, cx),
1605            vec![
1606                (message_1.id, Role::User, 0..3),
1607                (message_3.id, Role::User, 3..4),
1608            ]
1609        );
1610
1611        // Ensure we can still insert after a merged message.
1612        let message_5 = assistant.update(cx, |assistant, cx| {
1613            assistant
1614                .insert_message_after(message_1.id, Role::System, cx)
1615                .unwrap()
1616        });
1617        assert_eq!(
1618            messages(&assistant, cx),
1619            vec![
1620                (message_1.id, Role::User, 0..3),
1621                (message_5.id, Role::System, 3..4),
1622                (message_3.id, Role::User, 4..5)
1623            ]
1624        );
1625    }
1626
1627    #[gpui::test]
1628    fn test_message_splitting(cx: &mut AppContext) {
1629        let registry = Arc::new(LanguageRegistry::test());
1630        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
1631        let buffer = assistant.read(cx).buffer.clone();
1632
1633        let message_1 = assistant.read(cx).messages[0].clone();
1634        assert_eq!(
1635            messages(&assistant, cx),
1636            vec![(message_1.id, Role::User, 0..0)]
1637        );
1638
1639        buffer.update(cx, |buffer, cx| {
1640            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
1641        });
1642
1643        let (_, message_2) =
1644            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
1645        let message_2 = message_2.unwrap();
1646
1647        // We recycle newlines in the middle of a split message
1648        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
1649        assert_eq!(
1650            messages(&assistant, cx),
1651            vec![
1652                (message_1.id, Role::User, 0..4),
1653                (message_2.id, Role::User, 4..16),
1654            ]
1655        );
1656
1657        let (_, message_3) =
1658            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
1659        let message_3 = message_3.unwrap();
1660
1661        // We don't recycle newlines at the end of a split message
1662        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
1663        assert_eq!(
1664            messages(&assistant, cx),
1665            vec![
1666                (message_1.id, Role::User, 0..4),
1667                (message_3.id, Role::User, 4..5),
1668                (message_2.id, Role::User, 5..17),
1669            ]
1670        );
1671
1672        let (_, message_4) =
1673            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
1674        let message_4 = message_4.unwrap();
1675        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
1676        assert_eq!(
1677            messages(&assistant, cx),
1678            vec![
1679                (message_1.id, Role::User, 0..4),
1680                (message_3.id, Role::User, 4..5),
1681                (message_2.id, Role::User, 5..9),
1682                (message_4.id, Role::User, 9..17),
1683            ]
1684        );
1685
1686        let (_, message_5) =
1687            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
1688        let message_5 = message_5.unwrap();
1689        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
1690        assert_eq!(
1691            messages(&assistant, cx),
1692            vec![
1693                (message_1.id, Role::User, 0..4),
1694                (message_3.id, Role::User, 4..5),
1695                (message_2.id, Role::User, 5..9),
1696                (message_4.id, Role::User, 9..10),
1697                (message_5.id, Role::User, 10..18),
1698            ]
1699        );
1700
1701        let (message_6, message_7) =
1702            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
1703        let message_6 = message_6.unwrap();
1704        let message_7 = message_7.unwrap();
1705        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
1706        assert_eq!(
1707            messages(&assistant, cx),
1708            vec![
1709                (message_1.id, Role::User, 0..4),
1710                (message_3.id, Role::User, 4..5),
1711                (message_2.id, Role::User, 5..9),
1712                (message_4.id, Role::User, 9..10),
1713                (message_5.id, Role::User, 10..14),
1714                (message_6.id, Role::User, 14..17),
1715                (message_7.id, Role::User, 17..19),
1716            ]
1717        );
1718    }
1719
1720    fn messages(
1721        assistant: &ModelHandle<Assistant>,
1722        cx: &AppContext,
1723    ) -> Vec<(MessageId, Role, Range<usize>)> {
1724        assistant
1725            .read(cx)
1726            .messages(cx)
1727            .map(|(_, message, metadata, range)| (message.id, metadata.role, range))
1728            .collect()
1729    }
1730}