1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role,
4};
5use anyhow::{anyhow, Result};
6use chrono::{DateTime, Local};
7use collections::{HashMap, HashSet};
8use editor::{
9 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
10 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
11 Anchor, Editor, ToOffset,
12};
13use fs::Fs;
14use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
15use gpui::{
16 actions,
17 elements::*,
18 executor::Background,
19 geometry::vector::{vec2f, Vector2F},
20 platform::{CursorStyle, MouseButton},
21 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
22 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
23};
24use isahc::{http::StatusCode, Request, RequestExt};
25use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
26use serde::Deserialize;
27use settings::SettingsStore;
28use std::{
29 borrow::Cow, cell::RefCell, cmp, fmt::Write, io, iter, ops::Range, rc::Rc, sync::Arc,
30 time::Duration,
31};
32use util::{channel::ReleaseChannel, post_inc, truncate_and_trailoff, ResultExt, TryFutureExt};
33use workspace::{
34 dock::{DockPosition, Panel},
35 item::Item,
36 pane, Pane, Workspace,
37};
38
/// Base URL of the OpenAI API. It is also the key under which the API key is
/// written to, read from, and deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
40
// Actions in the `assistant` namespace. Their handlers are registered with the
// app in `init` (which also hides this namespace from the command palette on
// the stable channel).
actions!(
    assistant,
    [
        NewContext,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey
    ]
);
53
54pub fn init(cx: &mut AppContext) {
55 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
56 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
57 filter.filtered_namespaces.insert("assistant");
58 });
59 }
60
61 settings::register::<AssistantSettings>(cx);
62 cx.add_action(
63 |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
64 if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
65 this.update(cx, |this, cx| this.add_context(cx))
66 }
67
68 workspace.focus_panel::<AssistantPanel>(cx);
69 },
70 );
71 cx.add_action(AssistantEditor::assist);
72 cx.capture_action(AssistantEditor::cancel_last_assist);
73 cx.add_action(AssistantEditor::quote_selection);
74 cx.capture_action(AssistantEditor::copy);
75 cx.capture_action(AssistantEditor::split);
76 cx.capture_action(AssistantEditor::cycle_message_role);
77 cx.add_action(AssistantPanel::save_api_key);
78 cx.add_action(AssistantPanel::reset_api_key);
79 cx.add_action(
80 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
81 workspace.toggle_panel_focus::<AssistantPanel>(cx);
82 },
83 );
84}
85
/// Events emitted by [`AssistantPanel`].
///
/// The first four are forwarded from the panel's inner pane (see
/// `handle_pane_event`); the `Panel` trait hooks map them to zoom, focus,
/// close and reposition behavior.
pub enum AssistantPanelEvent {
    ZoomIn,
    ZoomOut,
    Focus,
    Close,
    /// The dock position stored in `AssistantSettings` changed.
    DockPositionChanged,
}
93
/// Dockable panel that hosts a pane of assistant conversation editors, or an
/// API-key prompt while no key is available.
pub struct AssistantPanel {
    /// Width chosen via `set_size` while docked left/right; `None` uses the settings default.
    width: Option<f32>,
    /// Height chosen via `set_size` while docked at the bottom; `None` uses the settings default.
    height: Option<f32>,
    /// Pane whose items are the conversation editors.
    pane: ViewHandle<Pane>,
    /// API key shared with every conversation created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// Single-line editor for entering the key; `Some` while the key prompt is shown.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once `set_active` has tried reading the key from the credential store.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
}
105
106impl AssistantPanel {
    /// Builds the assistant panel view inside `workspace`.
    ///
    /// The panel's inner pane cannot be split, navigated, or dropped onto. Its
    /// tab bar shows "New Context" and "Toggle Zoom" buttons, and its toolbar
    /// hosts a buffer search bar. The panel re-renders when the pane notifies,
    /// forwards pane events, and emits `DockPositionChanged` when a settings
    /// change moves the dock.
    ///
    /// The returned task resolves to the error from `workspace.update` if that
    /// fails (NOTE(review): presumably when the workspace was dropped — confirm).
    pub fn load(
        workspace: WeakViewHandle<Workspace>,
        cx: AsyncAppContext,
    ) -> Task<Result<ViewHandle<Self>>> {
        cx.spawn(|mut cx| async move {
            // TODO: deserialize state.
            workspace.update(&mut cx, |workspace, cx| {
                cx.add_view::<Self, _>(|cx| {
                    let weak_self = cx.weak_handle();
                    let pane = cx.add_view(|cx| {
                        let mut pane = Pane::new(
                            workspace.weak_handle(),
                            workspace.project().clone(),
                            workspace.app_state().background_actions,
                            Default::default(),
                            cx,
                        );
                        pane.set_can_split(false, cx);
                        pane.set_can_navigate(false, cx);
                        // Refuse every drag-and-drop onto this pane.
                        pane.on_can_drop(move |_, _| false);
                        pane.set_render_tab_bar_buttons(cx, move |pane, cx| {
                            let weak_self = weak_self.clone();
                            Flex::row()
                                .with_child(Pane::render_tab_bar_button(
                                    0,
                                    "icons/plus_12.svg",
                                    false,
                                    Some(("New Context".into(), Some(Box::new(NewContext)))),
                                    cx,
                                    move |_, cx| {
                                        let weak_self = weak_self.clone();
                                        // Add the context in a deferred callback, outside this
                                        // click handler (NOTE(review): presumably to avoid
                                        // updating the panel while the pane is being updated).
                                        cx.window_context().defer(move |cx| {
                                            if let Some(this) = weak_self.upgrade(cx) {
                                                this.update(cx, |this, cx| this.add_context(cx));
                                            }
                                        })
                                    },
                                    None,
                                ))
                                .with_child(Pane::render_tab_bar_button(
                                    1,
                                    if pane.is_zoomed() {
                                        "icons/minimize_8.svg"
                                    } else {
                                        "icons/maximize_8.svg"
                                    },
                                    pane.is_zoomed(),
                                    Some((
                                        "Toggle Zoom".into(),
                                        Some(Box::new(workspace::ToggleZoom)),
                                    )),
                                    cx,
                                    move |pane, cx| pane.toggle_zoom(&Default::default(), cx),
                                    None,
                                ))
                                .into_any()
                        });
                        let buffer_search_bar = cx.add_view(search::BufferSearchBar::new);
                        pane.toolbar()
                            .update(cx, |toolbar, cx| toolbar.add_item(buffer_search_bar, cx));
                        pane
                    });

                    let mut this = Self {
                        pane,
                        api_key: Rc::new(RefCell::new(None)),
                        api_key_editor: None,
                        has_read_credentials: false,
                        languages: workspace.app_state().languages.clone(),
                        fs: workspace.app_state().fs.clone(),
                        width: None,
                        height: None,
                        subscriptions: Default::default(),
                    };

                    // Remember the dock position so a settings change only emits
                    // `DockPositionChanged` when the position actually differs.
                    let mut old_dock_position = this.position(cx);
                    this.subscriptions = vec![
                        cx.observe(&this.pane, |_, _, cx| cx.notify()),
                        cx.subscribe(&this.pane, Self::handle_pane_event),
                        cx.observe_global::<SettingsStore, _>(move |this, cx| {
                            let new_dock_position = this.position(cx);
                            if new_dock_position != old_dock_position {
                                old_dock_position = new_dock_position;
                                cx.emit(AssistantPanelEvent::DockPositionChanged);
                            }
                        }),
                    ];

                    this
                })
            })
        })
    }
200
201 fn handle_pane_event(
202 &mut self,
203 _pane: ViewHandle<Pane>,
204 event: &pane::Event,
205 cx: &mut ViewContext<Self>,
206 ) {
207 match event {
208 pane::Event::ZoomIn => cx.emit(AssistantPanelEvent::ZoomIn),
209 pane::Event::ZoomOut => cx.emit(AssistantPanelEvent::ZoomOut),
210 pane::Event::Focus => cx.emit(AssistantPanelEvent::Focus),
211 pane::Event::Remove => cx.emit(AssistantPanelEvent::Close),
212 _ => {}
213 }
214 }
215
216 fn add_context(&mut self, cx: &mut ViewContext<Self>) {
217 let focus = self.has_focus(cx);
218 let editor = cx
219 .add_view(|cx| AssistantEditor::new(self.api_key.clone(), self.languages.clone(), cx));
220 self.subscriptions
221 .push(cx.subscribe(&editor, Self::handle_assistant_editor_event));
222 self.pane.update(cx, |pane, cx| {
223 pane.add_item(Box::new(editor), true, focus, None, cx)
224 });
225 }
226
227 fn handle_assistant_editor_event(
228 &mut self,
229 _: ViewHandle<AssistantEditor>,
230 event: &AssistantEditorEvent,
231 cx: &mut ViewContext<Self>,
232 ) {
233 match event {
234 AssistantEditorEvent::TabContentChanged => self.pane.update(cx, |_, cx| cx.notify()),
235 }
236 }
237
238 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
239 if let Some(api_key) = self
240 .api_key_editor
241 .as_ref()
242 .map(|editor| editor.read(cx).text(cx))
243 {
244 if !api_key.is_empty() {
245 cx.platform()
246 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
247 .log_err();
248 *self.api_key.borrow_mut() = Some(api_key);
249 self.api_key_editor.take();
250 cx.focus_self();
251 cx.notify();
252 }
253 } else {
254 cx.propagate_action();
255 }
256 }
257
258 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
259 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
260 self.api_key.take();
261 self.api_key_editor = Some(build_api_key_editor(cx));
262 cx.focus_self();
263 cx.notify();
264 }
265}
266
267fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
268 cx.add_view(|cx| {
269 let mut editor = Editor::single_line(
270 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
271 cx,
272 );
273 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
274 editor
275 })
276}
277
// The panel emits `AssistantPanelEvent`s, which the `Panel` trait hooks inspect.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
281
282impl View for AssistantPanel {
283 fn ui_name() -> &'static str {
284 "AssistantPanel"
285 }
286
287 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
288 let style = &theme::current(cx).assistant;
289 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
290 Flex::column()
291 .with_child(
292 Text::new(
293 "Paste your OpenAI API key and press Enter to use the assistant",
294 style.api_key_prompt.text.clone(),
295 )
296 .aligned(),
297 )
298 .with_child(
299 ChildView::new(api_key_editor, cx)
300 .contained()
301 .with_style(style.api_key_editor.container)
302 .aligned(),
303 )
304 .contained()
305 .with_style(style.api_key_prompt.container)
306 .aligned()
307 .into_any()
308 } else {
309 ChildView::new(&self.pane, cx).into_any()
310 }
311 }
312
313 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
314 if cx.is_self_focused() {
315 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
316 cx.focus(api_key_editor);
317 } else {
318 cx.focus(&self.pane);
319 }
320 }
321 }
322}
323
324impl Panel for AssistantPanel {
325 fn position(&self, cx: &WindowContext) -> DockPosition {
326 match settings::get::<AssistantSettings>(cx).dock {
327 AssistantDockPosition::Left => DockPosition::Left,
328 AssistantDockPosition::Bottom => DockPosition::Bottom,
329 AssistantDockPosition::Right => DockPosition::Right,
330 }
331 }
332
333 fn position_is_valid(&self, _: DockPosition) -> bool {
334 true
335 }
336
337 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
338 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
339 let dock = match position {
340 DockPosition::Left => AssistantDockPosition::Left,
341 DockPosition::Bottom => AssistantDockPosition::Bottom,
342 DockPosition::Right => AssistantDockPosition::Right,
343 };
344 settings.dock = Some(dock);
345 });
346 }
347
348 fn size(&self, cx: &WindowContext) -> f32 {
349 let settings = settings::get::<AssistantSettings>(cx);
350 match self.position(cx) {
351 DockPosition::Left | DockPosition::Right => {
352 self.width.unwrap_or_else(|| settings.default_width)
353 }
354 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
355 }
356 }
357
358 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
359 match self.position(cx) {
360 DockPosition::Left | DockPosition::Right => self.width = Some(size),
361 DockPosition::Bottom => self.height = Some(size),
362 }
363 cx.notify();
364 }
365
366 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
367 matches!(event, AssistantPanelEvent::ZoomIn)
368 }
369
370 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
371 matches!(event, AssistantPanelEvent::ZoomOut)
372 }
373
374 fn is_zoomed(&self, cx: &WindowContext) -> bool {
375 self.pane.read(cx).is_zoomed()
376 }
377
378 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
379 self.pane.update(cx, |pane, cx| pane.set_zoomed(zoomed, cx));
380 }
381
382 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
383 if active {
384 if self.api_key.borrow().is_none() && !self.has_read_credentials {
385 self.has_read_credentials = true;
386 let api_key = if let Some((_, api_key)) = cx
387 .platform()
388 .read_credentials(OPENAI_API_URL)
389 .log_err()
390 .flatten()
391 {
392 String::from_utf8(api_key).log_err()
393 } else {
394 None
395 };
396 if let Some(api_key) = api_key {
397 *self.api_key.borrow_mut() = Some(api_key);
398 } else if self.api_key_editor.is_none() {
399 self.api_key_editor = Some(build_api_key_editor(cx));
400 cx.notify();
401 }
402 }
403
404 if self.pane.read(cx).items_len() == 0 {
405 self.add_context(cx);
406 }
407 }
408 }
409
410 fn icon_path(&self) -> &'static str {
411 "icons/robot_14.svg"
412 }
413
414 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
415 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
416 }
417
418 fn should_change_position_on_event(event: &Self::Event) -> bool {
419 matches!(event, AssistantPanelEvent::DockPositionChanged)
420 }
421
422 fn should_activate_on_event(_: &Self::Event) -> bool {
423 false
424 }
425
426 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
427 matches!(event, AssistantPanelEvent::Close)
428 }
429
430 fn has_focus(&self, cx: &WindowContext) -> bool {
431 self.pane.read(cx).has_focus()
432 || self
433 .api_key_editor
434 .as_ref()
435 .map_or(false, |editor| editor.is_focused(cx))
436 }
437
438 fn is_focus_event(event: &Self::Event) -> bool {
439 matches!(event, AssistantPanelEvent::Focus)
440 }
441}
442
/// Events emitted by an [`Assistant`] model; handled by `AssistantEditor`.
enum AssistantEvent {
    /// The buffer was edited, or messages were inserted, split, or changed role.
    MessagesEdited,
    /// More text was appended to the conversation summary.
    SummaryChanged,
    /// A chunk of completion text was inserted into the buffer.
    StreamedCompletion,
}
448
/// Model for one assistant conversation.
///
/// All message text lives in a single buffer. Message boundaries are start
/// anchors in `message_anchors`, and per-message data lives in
/// `messages_metadata`.
struct Assistant {
    /// Text of all messages; see `messages` for how it is divided into messages.
    buffer: ModelHandle<Buffer>,
    /// Start anchor of each message, in buffer order.
    message_anchors: Vec<MessageAnchor>,
    /// Role, timestamp and status of each message, keyed by id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id given to the next message created.
    next_message_id: MessageId,
    /// Short title streamed in by `summarize`, if requested yet.
    summary: Option<String>,
    pending_summary: Task<Option<()>>,
    /// Number of completion batches registered by `assist`; also the next batch id.
    completion_count: usize,
    /// Completion batches started by `assist`; `cancel_last_assist` pops the latest.
    pending_completions: Vec<PendingCompletion>,
    /// Model name used for requests and for token counting.
    model: String,
    /// Tokens used by the conversation, once a count has completed.
    token_count: Option<usize>,
    /// Context size of `model`, per `tiktoken_rs::model::get_context_size`.
    max_token_count: usize,
    pending_token_count: Task<Option<()>>,
    /// API key shared with the panel; `None` until one is provided.
    api_key: Rc<RefCell<Option<String>>>,
    _subscriptions: Vec<Subscription>,
}
465
// The conversation model emits `AssistantEvent`s, consumed by `AssistantEditor`.
impl Entity for Assistant {
    type Event = AssistantEvent;
}
469
470impl Assistant {
    /// Creates a conversation with an empty buffer and a single empty user message.
    ///
    /// The buffer is switched to Markdown once that language has loaded. An
    /// initial token count is started immediately.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let model = "gpt-3.5-turbo-0613";
        let markdown = language_registry.language_for_name("Markdown");
        let buffer = cx.add_model(|cx| {
            let mut buffer = Buffer::new(0, "", cx);
            buffer.set_language_registry(language_registry);
            // Set the language asynchronously. Only a weak handle is held, so this
            // fails (and logs) if the buffer was dropped before Markdown loaded.
            cx.spawn_weak(|buffer, mut cx| async move {
                let markdown = markdown.await?;
                let buffer = buffer
                    .upgrade(&cx)
                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
                buffer.update(&mut cx, |buffer, cx| {
                    buffer.set_language(Some(markdown), cx)
                });
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
            buffer
        });

        let mut this = Self {
            message_anchors: Default::default(),
            messages_metadata: Default::default(),
            next_message_id: Default::default(),
            summary: None,
            pending_summary: Task::ready(None),
            completion_count: Default::default(),
            pending_completions: Default::default(),
            token_count: None,
            max_token_count: tiktoken_rs::model::get_context_size(model),
            pending_token_count: Task::ready(None),
            model: model.into(),
            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
            api_key,
            buffer,
        };
        // The first message starts at the very beginning of the buffer.
        let message = MessageAnchor {
            id: MessageId(post_inc(&mut this.next_message_id.0)),
            start: language::Anchor::MIN,
        };
        this.message_anchors.push(message.clone());
        this.messages_metadata.insert(
            message.id,
            MessageMetadata {
                role: Role::User,
                sent_at: Local::now(),
                status: MessageStatus::Done,
            },
        );

        this.count_remaining_tokens(cx);
        this
    }
528
529 fn handle_buffer_event(
530 &mut self,
531 _: ModelHandle<Buffer>,
532 event: &language::Event,
533 cx: &mut ModelContext<Self>,
534 ) {
535 match event {
536 language::Event::Edited => {
537 self.count_remaining_tokens(cx);
538 cx.emit(AssistantEvent::MessagesEdited);
539 }
540 _ => {}
541 }
542 }
543
544 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
545 let messages = self
546 .messages(cx)
547 .into_iter()
548 .filter_map(|message| {
549 Some(tiktoken_rs::ChatCompletionRequestMessage {
550 role: match message.role {
551 Role::User => "user".into(),
552 Role::Assistant => "assistant".into(),
553 Role::System => "system".into(),
554 },
555 content: self.buffer.read(cx).text_for_range(message.range).collect(),
556 name: None,
557 })
558 })
559 .collect::<Vec<_>>();
560 let model = self.model.clone();
561 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
562 async move {
563 cx.background().timer(Duration::from_millis(200)).await;
564 let token_count = cx
565 .background()
566 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
567 .await?;
568
569 this.upgrade(&cx)
570 .ok_or_else(|| anyhow!("assistant was dropped"))?
571 .update(&mut cx, |this, cx| {
572 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
573 this.token_count = Some(token_count);
574 cx.notify()
575 });
576 anyhow::Ok(())
577 }
578 .log_err()
579 });
580 }
581
582 fn remaining_tokens(&self) -> Option<isize> {
583 Some(self.max_token_count as isize - self.token_count? as isize)
584 }
585
    /// Switches the model used for requests and notifies observers. Tokens are
    /// recounted; `max_token_count` is refreshed for the new model when that
    /// count completes.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
591
    /// Requests assistant replies for the selected messages.
    ///
    /// Each selected message that has metadata is handled as follows:
    /// - Assistant message: a new, empty user message is inserted after it and
    ///   returned, so the caller can move the cursor there.
    /// - Any other message: all `Done` messages are sent to the model, plus
    ///   system messages steering the reply toward the selected one. The reply
    ///   streams into a new `Pending` assistant message inserted after it. The
    ///   message is skipped entirely when no API key is set.
    ///
    /// All streaming tasks started by one call are stored together as a single
    /// `PendingCompletion`, so `cancel_last_assist` cancels them as a batch.
    fn assist(
        &mut self,
        selected_messages: HashSet<MessageId>,
        cx: &mut ModelContext<Self>,
    ) -> Vec<MessageAnchor> {
        let mut user_messages = Vec::new();
        let mut tasks = Vec::new();
        for selected_message_id in selected_messages {
            let selected_message_role =
                if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
                    metadata.role
                } else {
                    continue;
                };

            if selected_message_role == Role::Assistant {
                // Responding to the assistant: open a fresh user message after it.
                if let Some(user_message) = self.insert_message_after(
                    selected_message_id,
                    Role::User,
                    MessageStatus::Done,
                    cx,
                ) {
                    user_messages.push(user_message);
                } else {
                    continue;
                }
            } else {
                let request = OpenAIRequest {
                    model: self.model.clone(),
                    messages: self
                        .messages(cx)
                        .filter(|message| matches!(message.status, MessageStatus::Done))
                        .flat_map(|message| {
                            // Right after the selected message, insert a system message
                            // telling the model to treat the messages that follow as
                            // background knowledge.
                            let mut system_message = None;
                            if message.id == selected_message_id {
                                system_message = Some(RequestMessage {
                                    role: Role::System,
                                    content: concat!(
                                        "Treat the following messages as additional knowledge you have learned about, ",
                                        "but act as if they were not part of this conversation. That is, treat them ",
                                        "as if the user didn't see them and couldn't possibly inquire about them."
                                    ).into()
                                });
                            }

                            Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
                        })
                        .chain(Some(RequestMessage {
                            role: Role::System,
                            content: format!(
                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
                                selected_message_id.0
                            ),
                        }))
                        .collect(),
                    stream: true,
                };

                // Without an API key nothing is sent and no assistant message is inserted.
                let Some(api_key) = self.api_key.borrow().clone() else { continue };
                let stream = stream_completion(api_key, cx.background().clone(), request);
                let assistant_message = self
                    .insert_message_after(
                        selected_message_id,
                        Role::Assistant,
                        MessageStatus::Pending,
                        cx,
                    )
                    .unwrap();

                tasks.push(cx.spawn_weak({
                    |this, mut cx| async move {
                        let assistant_message_id = assistant_message.id;
                        let stream_completion = async {
                            let mut messages = stream.await?;

                            while let Some(message) = messages.next().await {
                                let mut message = message?;
                                if let Some(choice) = message.choices.pop() {
                                    this.upgrade(&cx)
                                        .ok_or_else(|| anyhow!("assistant was dropped"))?
                                        .update(&mut cx, |this, cx| {
                                            // `?` here (Option) just skips this chunk when it has
                                            // no content or the assistant message is gone.
                                            let text: Arc<str> = choice.delta.content?.into();
                                            let message_ix = this.message_anchors.iter().position(
                                                |message| message.id == assistant_message_id,
                                            )?;
                                            this.buffer.update(cx, |buffer, cx| {
                                                // Append at the end of the assistant message: just
                                                // before the newline preceding the next valid
                                                // message, or at the end of the buffer.
                                                let offset = this.message_anchors[message_ix + 1..]
                                                    .iter()
                                                    .find(|message| message.start.is_valid(buffer))
                                                    .map_or(buffer.len(), |message| {
                                                        message
                                                            .start
                                                            .to_offset(buffer)
                                                            .saturating_sub(1)
                                                    });
                                                buffer.edit([(offset..offset, text)], None, cx);
                                            });
                                            cx.emit(AssistantEvent::StreamedCompletion);

                                            Some(())
                                        });
                                }
                                smol::future::yield_now().await;
                            }

                            // NOTE(review): this batch's id was taken with `post_inc`, so
                            // `completion_count` is already greater than it and this `retain`
                            // never removes the batch. Finished batches therefore stay in
                            // `pending_completions`. Matching on the batch id instead would
                            // also drop sibling tasks still streaming for other selected
                            // messages, so a fix needs more care than swapping the comparison.
                            this.upgrade(&cx)
                                .ok_or_else(|| anyhow!("assistant was dropped"))?
                                .update(&mut cx, |this, cx| {
                                    this.pending_completions.retain(|completion| {
                                        completion.id != this.completion_count
                                    });
                                    this.summarize(cx);
                                });

                            anyhow::Ok(())
                        };

                        // Record the outcome on the assistant message's status.
                        let result = stream_completion.await;
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |this, cx| {
                                if let Some(metadata) =
                                    this.messages_metadata.get_mut(&assistant_message.id)
                                {
                                    match result {
                                        Ok(_) => {
                                            metadata.status = MessageStatus::Done;
                                        }
                                        Err(error) => {
                                            metadata.status = MessageStatus::Error(
                                                error.to_string().trim().into(),
                                            );
                                        }
                                    }
                                    cx.notify();
                                }
                            });
                        }
                    }
                }));
            }
        }

        if !tasks.is_empty() {
            self.pending_completions.push(PendingCompletion {
                id: post_inc(&mut self.completion_count),
                _tasks: tasks,
            });
        }

        user_messages
    }
743
    /// Drops the most recently registered completion batch, which drops its
    /// tasks. Returns whether there was a batch to drop.
    ///
    /// NOTE(review): finished batches do not appear to be removed from
    /// `pending_completions`, so this may pop an already-completed batch and
    /// still return `true`. Confirm against the cleanup logic in `assist`.
    fn cancel_last_assist(&mut self) -> bool {
        self.pending_completions.pop().is_some()
    }
747
748 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
749 for id in ids {
750 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
751 metadata.role.cycle();
752 cx.emit(AssistantEvent::MessagesEdited);
753 cx.notify();
754 }
755 }
756 }
757
758 fn insert_message_after(
759 &mut self,
760 message_id: MessageId,
761 role: Role,
762 status: MessageStatus,
763 cx: &mut ModelContext<Self>,
764 ) -> Option<MessageAnchor> {
765 if let Some(prev_message_ix) = self
766 .message_anchors
767 .iter()
768 .position(|message| message.id == message_id)
769 {
770 let start = self.buffer.update(cx, |buffer, cx| {
771 let offset = self.message_anchors[prev_message_ix + 1..]
772 .iter()
773 .find(|message| message.start.is_valid(buffer))
774 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
775 buffer.edit([(offset..offset, "\n")], None, cx);
776 buffer.anchor_before(offset + 1)
777 });
778 let message = MessageAnchor {
779 id: MessageId(post_inc(&mut self.next_message_id.0)),
780 start,
781 };
782 self.message_anchors
783 .insert(prev_message_ix + 1, message.clone());
784 self.messages_metadata.insert(
785 message.id,
786 MessageMetadata {
787 role,
788 sent_at: Local::now(),
789 status,
790 },
791 );
792 cx.emit(AssistantEvent::MessagesEdited);
793 Some(message)
794 } else {
795 None
796 }
797 }
798
799 fn split_message(
800 &mut self,
801 range: Range<usize>,
802 cx: &mut ModelContext<Self>,
803 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
804 let start_message = self.message_for_offset(range.start, cx);
805 let end_message = self.message_for_offset(range.end, cx);
806 if let Some((start_message, end_message)) = start_message.zip(end_message) {
807 // Prevent splitting when range spans multiple messages.
808 if start_message.index != end_message.index {
809 return (None, None);
810 }
811
812 let message = start_message;
813 let role = message.role;
814 let mut edited_buffer = false;
815
816 let mut suffix_start = None;
817 if range.start > message.range.start && range.end < message.range.end - 1 {
818 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
819 suffix_start = Some(range.end + 1);
820 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
821 suffix_start = Some(range.end);
822 }
823 }
824
825 let suffix = if let Some(suffix_start) = suffix_start {
826 MessageAnchor {
827 id: MessageId(post_inc(&mut self.next_message_id.0)),
828 start: self.buffer.read(cx).anchor_before(suffix_start),
829 }
830 } else {
831 self.buffer.update(cx, |buffer, cx| {
832 buffer.edit([(range.end..range.end, "\n")], None, cx);
833 });
834 edited_buffer = true;
835 MessageAnchor {
836 id: MessageId(post_inc(&mut self.next_message_id.0)),
837 start: self.buffer.read(cx).anchor_before(range.end + 1),
838 }
839 };
840
841 self.message_anchors
842 .insert(message.index + 1, suffix.clone());
843 self.messages_metadata.insert(
844 suffix.id,
845 MessageMetadata {
846 role,
847 sent_at: Local::now(),
848 status: MessageStatus::Done,
849 },
850 );
851
852 let new_messages = if range.start == range.end || range.start == message.range.start {
853 (None, Some(suffix))
854 } else {
855 let mut prefix_end = None;
856 if range.start > message.range.start && range.end < message.range.end - 1 {
857 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
858 prefix_end = Some(range.start + 1);
859 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
860 == Some('\n')
861 {
862 prefix_end = Some(range.start);
863 }
864 }
865
866 let selection = if let Some(prefix_end) = prefix_end {
867 cx.emit(AssistantEvent::MessagesEdited);
868 MessageAnchor {
869 id: MessageId(post_inc(&mut self.next_message_id.0)),
870 start: self.buffer.read(cx).anchor_before(prefix_end),
871 }
872 } else {
873 self.buffer.update(cx, |buffer, cx| {
874 buffer.edit([(range.start..range.start, "\n")], None, cx)
875 });
876 edited_buffer = true;
877 MessageAnchor {
878 id: MessageId(post_inc(&mut self.next_message_id.0)),
879 start: self.buffer.read(cx).anchor_before(range.end + 1),
880 }
881 };
882
883 self.message_anchors
884 .insert(message.index + 1, selection.clone());
885 self.messages_metadata.insert(
886 selection.id,
887 MessageMetadata {
888 role,
889 sent_at: Local::now(),
890 status: MessageStatus::Done,
891 },
892 );
893 (Some(selection), Some(suffix))
894 };
895
896 if !edited_buffer {
897 cx.emit(AssistantEvent::MessagesEdited);
898 }
899 new_messages
900 } else {
901 (None, None)
902 }
903 }
904
905 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
906 if self.message_anchors.len() >= 2 && self.summary.is_none() {
907 let api_key = self.api_key.borrow().clone();
908 if let Some(api_key) = api_key {
909 let messages = self
910 .messages(cx)
911 .take(2)
912 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
913 .chain(Some(RequestMessage {
914 role: Role::User,
915 content:
916 "Summarize the conversation into a short title without punctuation"
917 .into(),
918 }));
919 let request = OpenAIRequest {
920 model: self.model.clone(),
921 messages: messages.collect(),
922 stream: true,
923 };
924
925 let stream = stream_completion(api_key, cx.background().clone(), request);
926 self.pending_summary = cx.spawn(|this, mut cx| {
927 async move {
928 let mut messages = stream.await?;
929
930 while let Some(message) = messages.next().await {
931 let mut message = message?;
932 if let Some(choice) = message.choices.pop() {
933 let text = choice.delta.content.unwrap_or_default();
934 this.update(&mut cx, |this, cx| {
935 this.summary.get_or_insert(String::new()).push_str(&text);
936 cx.emit(AssistantEvent::SummaryChanged);
937 });
938 }
939 }
940
941 anyhow::Ok(())
942 }
943 .log_err()
944 });
945 }
946 }
947 }
948
949 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
950 self.messages_for_offsets([offset], cx).pop()
951 }
952
    /// Returns the messages containing the given offsets, each at most once.
    ///
    /// Messages and offsets are walked together, so `offsets` is expected in
    /// ascending order (NOTE(review): callers pass selection heads — confirm
    /// they are sorted). An offset exactly on a boundary between two messages
    /// belongs to the later message. The end of the buffer belongs to the last
    /// message.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let buffer_len = self.buffer.read(cx).len();
        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        while let Some(offset) = offsets.next() {
            // Skip messages that end before the offset. A message ending exactly at
            // the offset is skipped too, unless the offset is the end of the buffer.
            while messages.peek().map_or(false, |message| {
                message.range.end < offset || (message.range.end == offset && offset < buffer_len)
            }) {
                messages.next();
            }
            // No message contains this offset.
            let Some(message) = messages.peek() else { continue };

            // Skip further offsets in the same message (or every remaining offset
            // once this message extends to the end of the buffer).
            while offsets.peek().map_or(false, |offset| {
                message.range.contains(offset) || message.range.end == buffer_len
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
983
984 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
985 let buffer = self.buffer.read(cx);
986 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
987 iter::from_fn(move || {
988 while let Some((ix, message_anchor)) = message_anchors.next() {
989 let metadata = self.messages_metadata.get(&message_anchor.id)?;
990 let message_start = message_anchor.start.to_offset(buffer);
991 let mut message_end = None;
992 while let Some((_, next_message)) = message_anchors.peek() {
993 if next_message.start.is_valid(buffer) {
994 message_end = Some(next_message.start);
995 break;
996 } else {
997 message_anchors.next();
998 }
999 }
1000 let message_end = message_end
1001 .unwrap_or(language::Anchor::MAX)
1002 .to_offset(buffer);
1003 return Some(Message {
1004 index: ix,
1005 range: message_start..message_end,
1006 id: message_anchor.id,
1007 anchor: message_anchor.start,
1008 role: metadata.role,
1009 sent_at: metadata.sent_at,
1010 status: metadata.status.clone(),
1011 });
1012 }
1013 None
1014 })
1015 }
1016}
1017
/// The streaming tasks started by one `Assistant::assist` call.
/// Dropping this value drops all of its tasks.
struct PendingCompletion {
    /// Value of `Assistant::completion_count` when the batch was registered.
    id: usize,
    _tasks: Vec<Task<()>>,
}
1022
/// Events emitted by `AssistantEditor` to the panel.
enum AssistantEditorEvent {
    /// The editor's tab content changed (emitted when the conversation summary changes).
    TabContentChanged,
}
1026
/// Scroll state remembered relative to the newest cursor. Used to keep the
/// viewport steady while streamed completion text is inserted.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: horizontal scroll position; `y`: display rows from the top of the
    /// viewport down to the cursor row.
    offset_before_cursor: Vector2F,
    /// Anchor of the newest cursor's head.
    cursor: Anchor,
}
1032
/// View for one conversation: an [`Editor`] over the [`Assistant`]'s buffer.
struct AssistantEditor {
    assistant: ModelHandle<Assistant>,
    editor: ViewHandle<Editor>,
    /// Ids of editor blocks owned by this view (NOTE(review): presumably the
    /// message-header blocks maintained by `update_message_headers` — confirm).
    blocks: HashSet<BlockId>,
    /// Cursor-relative scroll position to restore as completions stream in;
    /// cleared when the user scrolls away from it manually.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
1040
1041impl AssistantEditor {
    /// Creates a new conversation model and an editor over its buffer.
    ///
    /// The editor soft-wraps at the editor width and hides the gutter. The view
    /// re-renders when the model notifies, subscribes to model and editor
    /// events, and builds the initial message headers.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let assistant = cx.add_model(|cx| Assistant::new(api_key, language_registry, cx));
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(assistant.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&assistant, |_, _, cx| cx.notify()),
            cx.subscribe(&assistant, Self::handle_assistant_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            assistant,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }
1071
1072 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1073 let cursors = self.cursors(cx);
1074
1075 let user_messages = self.assistant.update(cx, |assistant, cx| {
1076 let selected_messages = assistant
1077 .messages_for_offsets(cursors, cx)
1078 .into_iter()
1079 .map(|message| message.id)
1080 .collect();
1081 assistant.assist(selected_messages, cx)
1082 });
1083 let new_selections = user_messages
1084 .iter()
1085 .map(|message| {
1086 let cursor = message
1087 .start
1088 .to_offset(self.assistant.read(cx).buffer.read(cx));
1089 cursor..cursor
1090 })
1091 .collect::<Vec<_>>();
1092 if !new_selections.is_empty() {
1093 self.editor.update(cx, |editor, cx| {
1094 editor.change_selections(
1095 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1096 cx,
1097 |selections| selections.select_ranges(new_selections),
1098 );
1099 });
1100 }
1101 }
1102
1103 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1104 if !self
1105 .assistant
1106 .update(cx, |assistant, _| assistant.cancel_last_assist())
1107 {
1108 cx.propagate_action();
1109 }
1110 }
1111
1112 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1113 let cursors = self.cursors(cx);
1114 self.assistant.update(cx, |assistant, cx| {
1115 let messages = assistant
1116 .messages_for_offsets(cursors, cx)
1117 .into_iter()
1118 .map(|message| message.id)
1119 .collect();
1120 assistant.cycle_message_roles(messages, cx)
1121 });
1122 }
1123
1124 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1125 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1126 selections
1127 .into_iter()
1128 .map(|selection| selection.head())
1129 .collect()
1130 }
1131
1132 fn handle_assistant_event(
1133 &mut self,
1134 _: ModelHandle<Assistant>,
1135 event: &AssistantEvent,
1136 cx: &mut ViewContext<Self>,
1137 ) {
1138 match event {
1139 AssistantEvent::MessagesEdited => self.update_message_headers(cx),
1140 AssistantEvent::SummaryChanged => {
1141 cx.emit(AssistantEditorEvent::TabContentChanged);
1142 }
1143 AssistantEvent::StreamedCompletion => {
1144 self.editor.update(cx, |editor, cx| {
1145 if let Some(scroll_position) = self.scroll_position {
1146 let snapshot = editor.snapshot(cx);
1147 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1148 let scroll_top =
1149 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1150 editor.set_scroll_position(
1151 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1152 cx,
1153 );
1154 }
1155 });
1156 }
1157 }
1158 }
1159
1160 fn handle_editor_event(
1161 &mut self,
1162 _: ViewHandle<Editor>,
1163 event: &editor::Event,
1164 cx: &mut ViewContext<Self>,
1165 ) {
1166 match event {
1167 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1168 let cursor_scroll_position = self.cursor_scroll_position(cx);
1169 if *autoscroll {
1170 self.scroll_position = cursor_scroll_position;
1171 } else if self.scroll_position != cursor_scroll_position {
1172 self.scroll_position = None;
1173 }
1174 }
1175 editor::Event::SelectionsChanged { .. } => {
1176 self.scroll_position = self.cursor_scroll_position(cx);
1177 }
1178 _ => {}
1179 }
1180 }
1181
1182 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1183 self.editor.update(cx, |editor, cx| {
1184 let snapshot = editor.snapshot(cx);
1185 let cursor = editor.selections.newest_anchor().head();
1186 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1187 let scroll_position = editor
1188 .scroll_manager
1189 .anchor()
1190 .scroll_position(&snapshot.display_snapshot);
1191
1192 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1193 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1194 Some(ScrollPosition {
1195 cursor,
1196 offset_before_cursor: vec2f(
1197 scroll_position.x(),
1198 cursor_row - scroll_position.y(),
1199 ),
1200 })
1201 } else {
1202 None
1203 }
1204 })
1205 }
1206
    /// Rebuilds the header blocks shown above each message in the editor.
    ///
    /// Each header contains the sender label, the send time, and an error icon.
    /// Every previously inserted block is removed. A fresh sticky block, two
    /// lines tall, is then inserted above every message the assistant currently
    /// reports.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // Panics if `as_singleton()` returns `None`, i.e. the editor's
            // multibuffer is not backed by a single excerpt.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .assistant
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let assistant = self.assistant.clone();
                        // `message` is moved into the render closure, so each header
                        // shows the role, timestamp, and status captured when the
                        // blocks were last rebuilt.
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state, false);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state, false);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state, false);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            // Pressing the left mouse button on the sender label
                            // cycles this message's role.
                            .on_down(MouseButton::Left, {
                                let assistant = assistant.clone();
                                move |_, _, cx| {
                                    assistant.update(cx, |assistant, cx| {
                                        assistant.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                // Send time, formatted as 12-hour clock with lowercase am/pm.
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                // Failed messages get an error icon whose tooltip
                                // shows the error text.
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            // Swap the old header blocks for the new ones, and remember the new
            // ids so the next rebuild can remove them.
            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1317
1318 fn quote_selection(
1319 workspace: &mut Workspace,
1320 _: &QuoteSelection,
1321 cx: &mut ViewContext<Workspace>,
1322 ) {
1323 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1324 return;
1325 };
1326 let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1327 return;
1328 };
1329
1330 let text = editor.read_with(cx, |editor, cx| {
1331 let range = editor.selections.newest::<usize>(cx).range();
1332 let buffer = editor.buffer().read(cx).snapshot(cx);
1333 let start_language = buffer.language_at(range.start);
1334 let end_language = buffer.language_at(range.end);
1335 let language_name = if start_language == end_language {
1336 start_language.map(|language| language.name())
1337 } else {
1338 None
1339 };
1340 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1341
1342 let selected_text = buffer.text_for_range(range).collect::<String>();
1343 if selected_text.is_empty() {
1344 None
1345 } else {
1346 Some(if language_name == "markdown" {
1347 selected_text
1348 .lines()
1349 .map(|line| format!("> {}", line))
1350 .collect::<Vec<_>>()
1351 .join("\n")
1352 } else {
1353 format!("```{language_name}\n{selected_text}\n```")
1354 })
1355 }
1356 });
1357
1358 // Activate the panel
1359 if !panel.read(cx).has_focus(cx) {
1360 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1361 }
1362
1363 if let Some(text) = text {
1364 panel.update(cx, |panel, cx| {
1365 if let Some(assistant) = panel
1366 .pane
1367 .read(cx)
1368 .active_item()
1369 .and_then(|item| item.downcast::<AssistantEditor>())
1370 .ok_or_else(|| anyhow!("no active context"))
1371 .log_err()
1372 {
1373 assistant.update(cx, |assistant, cx| {
1374 assistant
1375 .editor
1376 .update(cx, |editor, cx| editor.insert(&text, cx))
1377 });
1378 }
1379 });
1380 }
1381 }
1382
1383 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1384 let editor = self.editor.read(cx);
1385 let assistant = self.assistant.read(cx);
1386 if editor.selections.count() == 1 {
1387 let selection = editor.selections.newest::<usize>(cx);
1388 let mut copied_text = String::new();
1389 let mut spanned_messages = 0;
1390 for message in assistant.messages(cx) {
1391 if message.range.start >= selection.range().end {
1392 break;
1393 } else if message.range.end >= selection.range().start {
1394 let range = cmp::max(message.range.start, selection.range().start)
1395 ..cmp::min(message.range.end, selection.range().end);
1396 if !range.is_empty() {
1397 spanned_messages += 1;
1398 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1399 for chunk in assistant.buffer.read(cx).text_for_range(range) {
1400 copied_text.push_str(&chunk);
1401 }
1402 copied_text.push('\n');
1403 }
1404 }
1405 }
1406
1407 if spanned_messages > 1 {
1408 cx.platform()
1409 .write_to_clipboard(ClipboardItem::new(copied_text));
1410 return;
1411 }
1412 }
1413
1414 cx.propagate_action();
1415 }
1416
1417 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1418 self.assistant.update(cx, |assistant, cx| {
1419 let selections = self.editor.read(cx).selections.disjoint_anchors();
1420 for selection in selections.into_iter() {
1421 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
1422 let range = selection
1423 .map(|endpoint| endpoint.to_offset(&buffer))
1424 .range();
1425 assistant.split_message(range, cx);
1426 }
1427 });
1428 }
1429
1430 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
1431 self.assistant.update(cx, |assistant, cx| {
1432 let new_model = match assistant.model.as_str() {
1433 "gpt-4-0613" => "gpt-3.5-turbo-0613",
1434 _ => "gpt-4-0613",
1435 };
1436 assistant.set_model(new_model.into(), cx);
1437 });
1438 }
1439
1440 fn title(&self, cx: &AppContext) -> String {
1441 self.assistant
1442 .read(cx)
1443 .summary
1444 .clone()
1445 .unwrap_or_else(|| "New Context".into())
1446 }
1447}
1448
// `AssistantEditor` emits `AssistantEditorEvent`s, e.g. `TabContentChanged`
// when the assistant's summary changes.
impl Entity for AssistantEditor {
    type Event = AssistantEditorEvent;
}
1452
1453impl View for AssistantEditor {
1454 fn ui_name() -> &'static str {
1455 "AssistantEditor"
1456 }
1457
1458 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
1459 enum Model {}
1460 let theme = &theme::current(cx).assistant;
1461 let assistant = &self.assistant.read(cx);
1462 let model = assistant.model.clone();
1463 let remaining_tokens = assistant.remaining_tokens().map(|remaining_tokens| {
1464 let remaining_tokens_style = if remaining_tokens <= 0 {
1465 &theme.no_remaining_tokens
1466 } else {
1467 &theme.remaining_tokens
1468 };
1469 Label::new(
1470 remaining_tokens.to_string(),
1471 remaining_tokens_style.text.clone(),
1472 )
1473 .contained()
1474 .with_style(remaining_tokens_style.container)
1475 });
1476
1477 Stack::new()
1478 .with_child(
1479 ChildView::new(&self.editor, cx)
1480 .contained()
1481 .with_style(theme.container),
1482 )
1483 .with_child(
1484 Flex::row()
1485 .with_child(
1486 MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
1487 let style = theme.model.style_for(state, false);
1488 Label::new(model, style.text.clone())
1489 .contained()
1490 .with_style(style.container)
1491 })
1492 .with_cursor_style(CursorStyle::PointingHand)
1493 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
1494 )
1495 .with_children(remaining_tokens)
1496 .contained()
1497 .with_style(theme.model_info_container)
1498 .aligned()
1499 .top()
1500 .right(),
1501 )
1502 .into_any()
1503 }
1504
1505 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
1506 if cx.is_self_focused() {
1507 cx.focus(&self.editor);
1508 }
1509 }
1510}
1511
1512impl Item for AssistantEditor {
1513 fn tab_content<V: View>(
1514 &self,
1515 _: Option<usize>,
1516 style: &theme::Tab,
1517 cx: &gpui::AppContext,
1518 ) -> AnyElement<V> {
1519 let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1520 Label::new(title, style.label.clone()).into_any()
1521 }
1522
1523 fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1524 Some(self.title(cx).into())
1525 }
1526
1527 fn as_searchable(
1528 &self,
1529 _: &ViewHandle<Self>,
1530 ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1531 Some(Box::new(self.editor.clone()))
1532 }
1533}
1534
/// Identifier of a message within an assistant conversation.
///
/// The inner value is also used as the mouse/tooltip region id for the
/// message header, and appears in the `[Message <id>]` prefix sent to OpenAI.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
struct MessageId(usize);
1537
/// A message's id paired with a buffer anchor at the message's start.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    // Resolved with `to_offset`, e.g. to place the cursor after an assist.
    start: language::Anchor,
}
1543
/// Per-message metadata: who sent it, when, and its completion status.
///
/// NOTE(review): not referenced in this part of the file; presumably kept per
/// `MessageId` elsewhere — confirm against the rest of the module.
#[derive(Clone, Debug)]
struct MessageMetadata {
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
1550
/// Status of a message.
#[derive(Clone, Debug)]
enum MessageStatus {
    // Presumably set while a completion is still in flight — confirm.
    Pending,
    Done,
    // Carries the error text shown in the header's error-icon tooltip.
    Error(Arc<str>),
}
1557
/// A resolved view of one message: its buffer range plus its metadata.
#[derive(Clone, Debug)]
pub struct Message {
    // Byte-offset range of the message's text in the assistant buffer.
    range: Range<usize>,
    // NOTE(review): not read in this part of the file; presumably the message's
    // position in the conversation — confirm.
    index: usize,
    id: MessageId,
    // Anchor used to position this message's header block in the editor.
    anchor: language::Anchor,
    role: Role,
    // Shown in the message header as the send time.
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
1568
1569impl Message {
1570 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
1571 let mut content = format!("[Message {}]\n", self.id.0).to_string();
1572 content.extend(buffer.text_for_range(self.range.clone()));
1573 RequestMessage {
1574 role: self.role,
1575 content,
1576 }
1577 }
1578}
1579
1580async fn stream_completion(
1581 api_key: String,
1582 executor: Arc<Background>,
1583 mut request: OpenAIRequest,
1584) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1585 request.stream = true;
1586
1587 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1588
1589 let json_data = serde_json::to_string(&request)?;
1590 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1591 .header("Content-Type", "application/json")
1592 .header("Authorization", format!("Bearer {}", api_key))
1593 .body(json_data)?
1594 .send_async()
1595 .await?;
1596
1597 let status = response.status();
1598 if status == StatusCode::OK {
1599 executor
1600 .spawn(async move {
1601 let mut lines = BufReader::new(response.body_mut()).lines();
1602
1603 fn parse_line(
1604 line: Result<String, io::Error>,
1605 ) -> Result<Option<OpenAIResponseStreamEvent>> {
1606 if let Some(data) = line?.strip_prefix("data: ") {
1607 let event = serde_json::from_str(&data)?;
1608 Ok(Some(event))
1609 } else {
1610 Ok(None)
1611 }
1612 }
1613
1614 while let Some(line) = lines.next().await {
1615 if let Some(event) = parse_line(line).transpose() {
1616 let done = event.as_ref().map_or(false, |event| {
1617 event
1618 .choices
1619 .last()
1620 .map_or(false, |choice| choice.finish_reason.is_some())
1621 });
1622 if tx.unbounded_send(event).is_err() {
1623 break;
1624 }
1625
1626 if done {
1627 break;
1628 }
1629 }
1630 }
1631
1632 anyhow::Ok(())
1633 })
1634 .detach();
1635
1636 Ok(rx)
1637 } else {
1638 let mut body = String::new();
1639 response.body_mut().read_to_string(&mut body).await?;
1640
1641 #[derive(Deserialize)]
1642 struct OpenAIResponse {
1643 error: OpenAIError,
1644 }
1645
1646 #[derive(Deserialize)]
1647 struct OpenAIError {
1648 message: String,
1649 }
1650
1651 match serde_json::from_str::<OpenAIResponse>(&body) {
1652 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1653 "Failed to connect to OpenAI API: {}",
1654 response.error.message,
1655 )),
1656
1657 _ => Err(anyhow!(
1658 "Failed to connect to OpenAI API: {} {}",
1659 response.status(),
1660 body,
1661 )),
1662 }
1663 }
1664}
1665
// Unit tests for how the `Assistant` model maps buffer edits onto message
// ranges: insertion, merging on deletion, undo/redo, splitting, and
// offset-to-message lookup.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::AppContext;

    /// Inserting messages, editing across message boundaries, and undo/redo
    /// must keep every message's range consistent with the buffer text.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        // A new assistant starts with a single, empty user message.
        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 grows it to 0..1 (one separator character)
        // and starts message_2 right after.
        let message_2 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message between
        // message_2 and message_3.
        let message_4 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = assistant.update(cx, |assistant, cx| {
            assistant
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// `split_message` reuses an existing newline at the split point when one
    /// is there, and inserts newlines otherwise; the assertions below pin the
    /// resulting buffer text and ranges.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting on a non-empty range yields messages on both sides of it.
        let (message_6, message_7) =
            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// `messages_for_offsets` maps offsets to the messages containing them,
    /// reporting each message at most once.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
        let buffer = assistant.read(cx).buffer.clone();

        let message_1 = assistant.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&assistant, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = assistant
            .update(cx, |assistant, cx| {
                assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = assistant
            .update(cx, |assistant, cx| {
                assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&assistant, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&assistant, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in message_1, which is reported once; the
        // end-of-buffer offset 11 maps to the last message.
        assert_eq!(
            message_ids_for_offsets(&assistant, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // Returns the ids of the messages reported for `offsets`.
        fn message_ids_for_offsets(
            assistant: &ModelHandle<Assistant>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            assistant
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    // Snapshot of every message as `(id, role, range)`, for compact assertions.
    fn messages(
        assistant: &ModelHandle<Assistant>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        assistant
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.range))
            .collect()
    }
}