1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
4 Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::StreamExt;
16use gpui::{
17 actions,
18 elements::*,
19 geometry::vector::{vec2f, Vector2F},
20 platform::{CursorStyle, MouseButton},
21 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
22 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
23};
24use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
25use search::BufferSearchBar;
26use settings::SettingsStore;
27use std::{
28 cell::RefCell,
29 cmp, env,
30 fmt::Write,
31 iter,
32 ops::Range,
33 path::{Path, PathBuf},
34 rc::Rc,
35 sync::Arc,
36 time::Duration,
37};
38use theme::AssistantStyle;
39use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
40use workspace::{
41 dock::{DockPosition, Panel},
42 searchable::Direction,
43 Save, ToggleZoom, Toolbar, Workspace,
44};
45
// Actions in the `assistant` namespace. They are bound to handlers in `init`
// below and referenced by the panel's header buttons and tooltips.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
58
/// Registers the assistant settings and every assistant action handler.
pub fn init(cx: &mut AppContext) {
    settings::register::<AssistantSettings>(cx);
    // While the assistant panel handles it, `workspace::NewFile` opens a new
    // conversation instead of a new file.
    cx.add_action(
        |this: &mut AssistantPanel,
         _: &workspace::NewFile,
         cx: &mut ViewContext<AssistantPanel>| {
            this.new_conversation(cx);
        },
    );
    cx.add_action(ConversationEditor::assist);
    // NOTE(review): the `capture_action` registrations presumably run before the
    // focused child's own handlers (e.g. the inner text editor's copy/save) so the
    // conversation editor can intercept them — confirm against gpui.
    cx.capture_action(ConversationEditor::cancel_last_assist);
    cx.capture_action(ConversationEditor::save);
    cx.add_action(ConversationEditor::quote_selection);
    cx.capture_action(ConversationEditor::copy);
    cx.add_action(ConversationEditor::split);
    cx.capture_action(ConversationEditor::cycle_message_role);
    cx.add_action(AssistantPanel::save_api_key);
    cx.add_action(AssistantPanel::reset_api_key);
    cx.add_action(AssistantPanel::toggle_zoom);
    cx.add_action(AssistantPanel::deploy);
    cx.add_action(AssistantPanel::select_next_match);
    cx.add_action(AssistantPanel::select_prev_match);
    cx.add_action(AssistantPanel::handle_editor_cancel);
    // Toggling focus is a workspace-level action so it works even when the
    // panel itself is not focused.
    cx.add_action(
        |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        },
    );
}
88
/// Events the assistant panel emits to its dock. Each variant is matched by the
/// corresponding `should_*_on_event` / `is_focus_event` hook in the `Panel` impl.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Request to zoom the panel (emitted by `toggle_zoom` while not zoomed).
    ZoomIn,
    /// Request to leave zoom (emitted by `toggle_zoom` while zoomed).
    ZoomOut,
    Focus,
    Close,
    /// The `dock` setting changed; emitted by the settings observer set up in `load`.
    DockPositionChanged,
}
97
/// Dockable panel hosting assistant conversations and the saved-conversation history.
pub struct AssistantPanel {
    // Weak handle to the owning workspace; used by the quote button.
    workspace: WeakViewHandle<Workspace>,
    // User-resized width/height; `None` falls back to the settings defaults.
    width: Option<f32>,
    height: Option<f32>,
    // Index into `editors` of the conversation being shown; `None` shows the history list.
    active_editor_index: Option<usize>,
    // Index that was active before the last change; restored by the history button.
    prev_active_editor_index: Option<usize>,
    // Every conversation editor opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    // Saved conversations shown in the history list; refreshed by the directory watcher.
    saved_conversations: Vec<SavedConversationMetadata>,
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    // Tracked from `focus_in`/`focus_out`; gates the header action buttons.
    has_focus: bool,
    // Toolbar hosting the buffer search bar for the active conversation.
    toolbar: ViewHandle<Toolbar>,
    // OpenAI API key, shared with every conversation created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    // Present while the panel is prompting the user for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Ensures the environment / credential store is consulted only once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
    // Held (never read) so the saved-conversation watcher task stays alive with the panel.
    _watch_saved_conversations: Task<Result<()>>,
}
118
119impl AssistantPanel {
120 pub fn load(
121 workspace: WeakViewHandle<Workspace>,
122 cx: AsyncAppContext,
123 ) -> Task<Result<ViewHandle<Self>>> {
124 cx.spawn(|mut cx| async move {
125 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
126 let saved_conversations = SavedConversationMetadata::list(fs.clone())
127 .await
128 .log_err()
129 .unwrap_or_default();
130
131 // TODO: deserialize state.
132 let workspace_handle = workspace.clone();
133 workspace.update(&mut cx, |workspace, cx| {
134 cx.add_view::<Self, _>(|cx| {
135 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
136 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
137 let mut events = fs
138 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
139 .await;
140 while events.next().await.is_some() {
141 let saved_conversations = SavedConversationMetadata::list(fs.clone())
142 .await
143 .log_err()
144 .unwrap_or_default();
145 this.update(&mut cx, |this, cx| {
146 this.saved_conversations = saved_conversations;
147 cx.notify();
148 })
149 .ok();
150 }
151
152 anyhow::Ok(())
153 });
154
155 let toolbar = cx.add_view(|cx| {
156 let mut toolbar = Toolbar::new();
157 toolbar.set_can_navigate(false, cx);
158 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
159 toolbar
160 });
161 let mut this = Self {
162 workspace: workspace_handle,
163 active_editor_index: Default::default(),
164 prev_active_editor_index: Default::default(),
165 editors: Default::default(),
166 saved_conversations,
167 saved_conversations_list_state: Default::default(),
168 zoomed: false,
169 has_focus: false,
170 toolbar,
171 api_key: Rc::new(RefCell::new(None)),
172 api_key_editor: None,
173 has_read_credentials: false,
174 languages: workspace.app_state().languages.clone(),
175 fs: workspace.app_state().fs.clone(),
176 width: None,
177 height: None,
178 subscriptions: Default::default(),
179 _watch_saved_conversations,
180 };
181
182 let mut old_dock_position = this.position(cx);
183 this.subscriptions =
184 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
185 let new_dock_position = this.position(cx);
186 if new_dock_position != old_dock_position {
187 old_dock_position = new_dock_position;
188 cx.emit(AssistantPanelEvent::DockPositionChanged);
189 }
190 cx.notify();
191 })];
192
193 this
194 })
195 })
196 })
197 }
198
199 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
200 let editor = cx.add_view(|cx| {
201 ConversationEditor::new(
202 self.api_key.clone(),
203 self.languages.clone(),
204 self.fs.clone(),
205 cx,
206 )
207 });
208 self.add_conversation(editor.clone(), cx);
209 editor
210 }
211
212 fn add_conversation(
213 &mut self,
214 editor: ViewHandle<ConversationEditor>,
215 cx: &mut ViewContext<Self>,
216 ) {
217 self.subscriptions
218 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
219
220 let conversation = editor.read(cx).conversation.clone();
221 self.subscriptions
222 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
223
224 let index = self.editors.len();
225 self.editors.push(editor);
226 self.set_active_editor_index(Some(index), cx);
227 }
228
229 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
230 self.prev_active_editor_index = self.active_editor_index;
231 self.active_editor_index = index;
232 if let Some(editor) = self.active_editor() {
233 let editor = editor.read(cx).editor.clone();
234 self.toolbar.update(cx, |toolbar, cx| {
235 toolbar.set_active_item(Some(&editor), cx);
236 });
237 if self.has_focus(cx) {
238 cx.focus(&editor);
239 }
240 } else {
241 self.toolbar.update(cx, |toolbar, cx| {
242 toolbar.set_active_item(None, cx);
243 });
244 }
245
246 cx.notify();
247 }
248
249 fn handle_conversation_editor_event(
250 &mut self,
251 _: ViewHandle<ConversationEditor>,
252 event: &ConversationEditorEvent,
253 cx: &mut ViewContext<Self>,
254 ) {
255 match event {
256 ConversationEditorEvent::TabContentChanged => cx.notify(),
257 }
258 }
259
260 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
261 if let Some(api_key) = self
262 .api_key_editor
263 .as_ref()
264 .map(|editor| editor.read(cx).text(cx))
265 {
266 if !api_key.is_empty() {
267 cx.platform()
268 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
269 .log_err();
270 *self.api_key.borrow_mut() = Some(api_key);
271 self.api_key_editor.take();
272 cx.focus_self();
273 cx.notify();
274 }
275 } else {
276 cx.propagate_action();
277 }
278 }
279
280 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
281 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
282 self.api_key.take();
283 self.api_key_editor = Some(build_api_key_editor(cx));
284 cx.focus_self();
285 cx.notify();
286 }
287
288 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
289 if self.zoomed {
290 cx.emit(AssistantPanelEvent::ZoomOut)
291 } else {
292 cx.emit(AssistantPanelEvent::ZoomIn)
293 }
294 }
295
296 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
297 let mut propagate_action = true;
298 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
299 search_bar.update(cx, |search_bar, cx| {
300 if search_bar.show(cx) {
301 search_bar.search_suggested(cx);
302 if action.focus {
303 search_bar.select_query(cx);
304 cx.focus_self();
305 }
306 propagate_action = false
307 }
308 });
309 }
310 if propagate_action {
311 cx.propagate_action();
312 }
313 }
314
315 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
316 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
317 if !search_bar.read(cx).is_dismissed() {
318 search_bar.update(cx, |search_bar, cx| {
319 search_bar.dismiss(&Default::default(), cx)
320 });
321 return;
322 }
323 }
324 cx.propagate_action();
325 }
326
327 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
328 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
329 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
330 }
331 }
332
333 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
334 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
335 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
336 }
337 }
338
339 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
340 self.editors.get(self.active_editor_index?)
341 }
342
343 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
344 enum History {}
345 let theme = theme::current(cx);
346 let tooltip_style = theme::current(cx).tooltip.clone();
347 MouseEventHandler::new::<History, _>(0, cx, |state, _| {
348 let style = theme.assistant.hamburger_button.style_for(state);
349 Svg::for_style(style.icon.clone())
350 .contained()
351 .with_style(style.container)
352 })
353 .with_cursor_style(CursorStyle::PointingHand)
354 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
355 if this.active_editor().is_some() {
356 this.set_active_editor_index(None, cx);
357 } else {
358 this.set_active_editor_index(this.prev_active_editor_index, cx);
359 }
360 })
361 .with_tooltip::<History>(1, "History", None, tooltip_style, cx)
362 }
363
364 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
365 if self.active_editor().is_some() {
366 vec![
367 Self::render_split_button(cx).into_any(),
368 Self::render_quote_button(cx).into_any(),
369 Self::render_assist_button(cx).into_any(),
370 ]
371 } else {
372 Default::default()
373 }
374 }
375
376 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
377 let theme = theme::current(cx);
378 let tooltip_style = theme::current(cx).tooltip.clone();
379 MouseEventHandler::new::<Split, _>(0, cx, |state, _| {
380 let style = theme.assistant.split_button.style_for(state);
381 Svg::for_style(style.icon.clone())
382 .contained()
383 .with_style(style.container)
384 })
385 .with_cursor_style(CursorStyle::PointingHand)
386 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
387 if let Some(active_editor) = this.active_editor() {
388 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
389 }
390 })
391 .with_tooltip::<Split>(
392 1,
393 "Split Message",
394 Some(Box::new(Split)),
395 tooltip_style,
396 cx,
397 )
398 }
399
400 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
401 let theme = theme::current(cx);
402 let tooltip_style = theme::current(cx).tooltip.clone();
403 MouseEventHandler::new::<Assist, _>(0, cx, |state, _| {
404 let style = theme.assistant.assist_button.style_for(state);
405 Svg::for_style(style.icon.clone())
406 .contained()
407 .with_style(style.container)
408 })
409 .with_cursor_style(CursorStyle::PointingHand)
410 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
411 if let Some(active_editor) = this.active_editor() {
412 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
413 }
414 })
415 .with_tooltip::<Assist>(1, "Assist", Some(Box::new(Assist)), tooltip_style, cx)
416 }
417
418 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
419 let theme = theme::current(cx);
420 let tooltip_style = theme::current(cx).tooltip.clone();
421 MouseEventHandler::new::<QuoteSelection, _>(0, cx, |state, _| {
422 let style = theme.assistant.quote_button.style_for(state);
423 Svg::for_style(style.icon.clone())
424 .contained()
425 .with_style(style.container)
426 })
427 .with_cursor_style(CursorStyle::PointingHand)
428 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
429 if let Some(workspace) = this.workspace.upgrade(cx) {
430 cx.window_context().defer(move |cx| {
431 workspace.update(cx, |workspace, cx| {
432 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
433 });
434 });
435 }
436 })
437 .with_tooltip::<QuoteSelection>(
438 1,
439 "Quote Selection",
440 Some(Box::new(QuoteSelection)),
441 tooltip_style,
442 cx,
443 )
444 }
445
446 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
447 let theme = theme::current(cx);
448 let tooltip_style = theme::current(cx).tooltip.clone();
449 MouseEventHandler::new::<NewConversation, _>(0, cx, |state, _| {
450 let style = theme.assistant.plus_button.style_for(state);
451 Svg::for_style(style.icon.clone())
452 .contained()
453 .with_style(style.container)
454 })
455 .with_cursor_style(CursorStyle::PointingHand)
456 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
457 this.new_conversation(cx);
458 })
459 .with_tooltip::<NewConversation>(
460 1,
461 "New Conversation",
462 Some(Box::new(NewConversation)),
463 tooltip_style,
464 cx,
465 )
466 }
467
468 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
469 enum ToggleZoomButton {}
470
471 let theme = theme::current(cx);
472 let tooltip_style = theme::current(cx).tooltip.clone();
473 let style = if self.zoomed {
474 &theme.assistant.zoom_out_button
475 } else {
476 &theme.assistant.zoom_in_button
477 };
478
479 MouseEventHandler::new::<ToggleZoomButton, _>(0, cx, |state, _| {
480 let style = style.style_for(state);
481 Svg::for_style(style.icon.clone())
482 .contained()
483 .with_style(style.container)
484 })
485 .with_cursor_style(CursorStyle::PointingHand)
486 .on_click(MouseButton::Left, |_, this, cx| {
487 this.toggle_zoom(&ToggleZoom, cx);
488 })
489 .with_tooltip::<ToggleZoom>(
490 0,
491 if self.zoomed { "Zoom Out" } else { "Zoom In" },
492 Some(Box::new(ToggleZoom)),
493 tooltip_style,
494 cx,
495 )
496 }
497
498 fn render_saved_conversation(
499 &mut self,
500 index: usize,
501 cx: &mut ViewContext<Self>,
502 ) -> impl Element<Self> {
503 let conversation = &self.saved_conversations[index];
504 let path = conversation.path.clone();
505 MouseEventHandler::new::<SavedConversationMetadata, _>(index, cx, move |state, cx| {
506 let style = &theme::current(cx).assistant.saved_conversation;
507 Flex::row()
508 .with_child(
509 Label::new(
510 conversation.mtime.format("%F %I:%M%p").to_string(),
511 style.saved_at.text.clone(),
512 )
513 .aligned()
514 .contained()
515 .with_style(style.saved_at.container),
516 )
517 .with_child(
518 Label::new(conversation.title.clone(), style.title.text.clone())
519 .aligned()
520 .contained()
521 .with_style(style.title.container),
522 )
523 .contained()
524 .with_style(*style.container.style_for(state))
525 })
526 .with_cursor_style(CursorStyle::PointingHand)
527 .on_click(MouseButton::Left, move |_, this, cx| {
528 this.open_conversation(path.clone(), cx)
529 .detach_and_log_err(cx)
530 })
531 }
532
533 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
534 if let Some(ix) = self.editor_index_for_path(&path, cx) {
535 self.set_active_editor_index(Some(ix), cx);
536 return Task::ready(Ok(()));
537 }
538
539 let fs = self.fs.clone();
540 let api_key = self.api_key.clone();
541 let languages = self.languages.clone();
542 cx.spawn(|this, mut cx| async move {
543 let saved_conversation = fs.load(&path).await?;
544 let saved_conversation = serde_json::from_str(&saved_conversation)?;
545 let conversation = cx.add_model(|cx| {
546 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
547 });
548 this.update(&mut cx, |this, cx| {
549 // If, by the time we've loaded the conversation, the user has already opened
550 // the same conversation, we don't want to open it again.
551 if let Some(ix) = this.editor_index_for_path(&path, cx) {
552 this.set_active_editor_index(Some(ix), cx);
553 } else {
554 let editor = cx
555 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
556 this.add_conversation(editor, cx);
557 }
558 })?;
559 Ok(())
560 })
561 }
562
563 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
564 self.editors
565 .iter()
566 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
567 }
568}
569
570fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
571 cx.add_view(|cx| {
572 let mut editor = Editor::single_line(
573 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
574 cx,
575 );
576 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
577 editor
578 })
579}
580
// The panel communicates with its dock through `AssistantPanelEvent`s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
584
585impl View for AssistantPanel {
586 fn ui_name() -> &'static str {
587 "AssistantPanel"
588 }
589
590 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
591 let theme = &theme::current(cx);
592 let style = &theme.assistant;
593 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
594 Flex::column()
595 .with_child(
596 Text::new(
597 "Paste your OpenAI API key and press Enter to use the assistant",
598 style.api_key_prompt.text.clone(),
599 )
600 .aligned(),
601 )
602 .with_child(
603 ChildView::new(api_key_editor, cx)
604 .contained()
605 .with_style(style.api_key_editor.container)
606 .aligned(),
607 )
608 .contained()
609 .with_style(style.api_key_prompt.container)
610 .aligned()
611 .into_any()
612 } else {
613 let title = self.active_editor().map(|editor| {
614 Label::new(editor.read(cx).title(cx), style.title.text.clone())
615 .contained()
616 .with_style(style.title.container)
617 .aligned()
618 .left()
619 .flex(1., false)
620 });
621 let mut header = Flex::row()
622 .with_child(Self::render_hamburger_button(cx).aligned())
623 .with_children(title);
624 if self.has_focus {
625 header.add_children(
626 self.render_editor_tools(cx)
627 .into_iter()
628 .map(|tool| tool.aligned().flex_float()),
629 );
630 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
631 header.add_child(self.render_zoom_button(cx).aligned());
632 }
633
634 Flex::column()
635 .with_child(
636 header
637 .contained()
638 .with_style(theme.workspace.tab_bar.container)
639 .expanded()
640 .constrained()
641 .with_height(theme.workspace.tab_bar.height),
642 )
643 .with_children(if self.toolbar.read(cx).hidden() {
644 None
645 } else {
646 Some(ChildView::new(&self.toolbar, cx).expanded())
647 })
648 .with_child(if let Some(editor) = self.active_editor() {
649 ChildView::new(editor, cx).flex(1., true).into_any()
650 } else {
651 UniformList::new(
652 self.saved_conversations_list_state.clone(),
653 self.saved_conversations.len(),
654 cx,
655 |this, range, items, cx| {
656 for ix in range {
657 items.push(this.render_saved_conversation(ix, cx).into_any());
658 }
659 },
660 )
661 .flex(1., true)
662 .into_any()
663 })
664 .into_any()
665 }
666 }
667
668 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
669 self.has_focus = true;
670 self.toolbar
671 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
672 cx.notify();
673 if cx.is_self_focused() {
674 if let Some(editor) = self.active_editor() {
675 cx.focus(editor);
676 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
677 cx.focus(api_key_editor);
678 }
679 }
680 }
681
682 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
683 self.has_focus = false;
684 self.toolbar
685 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
686 cx.notify();
687 }
688}
689
690impl Panel for AssistantPanel {
691 fn position(&self, cx: &WindowContext) -> DockPosition {
692 match settings::get::<AssistantSettings>(cx).dock {
693 AssistantDockPosition::Left => DockPosition::Left,
694 AssistantDockPosition::Bottom => DockPosition::Bottom,
695 AssistantDockPosition::Right => DockPosition::Right,
696 }
697 }
698
699 fn position_is_valid(&self, _: DockPosition) -> bool {
700 true
701 }
702
703 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
704 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
705 let dock = match position {
706 DockPosition::Left => AssistantDockPosition::Left,
707 DockPosition::Bottom => AssistantDockPosition::Bottom,
708 DockPosition::Right => AssistantDockPosition::Right,
709 };
710 settings.dock = Some(dock);
711 });
712 }
713
714 fn size(&self, cx: &WindowContext) -> f32 {
715 let settings = settings::get::<AssistantSettings>(cx);
716 match self.position(cx) {
717 DockPosition::Left | DockPosition::Right => {
718 self.width.unwrap_or_else(|| settings.default_width)
719 }
720 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
721 }
722 }
723
724 fn set_size(&mut self, size: Option<f32>, cx: &mut ViewContext<Self>) {
725 match self.position(cx) {
726 DockPosition::Left | DockPosition::Right => self.width = size,
727 DockPosition::Bottom => self.height = size,
728 }
729 cx.notify();
730 }
731
732 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
733 matches!(event, AssistantPanelEvent::ZoomIn)
734 }
735
736 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
737 matches!(event, AssistantPanelEvent::ZoomOut)
738 }
739
740 fn is_zoomed(&self, _: &WindowContext) -> bool {
741 self.zoomed
742 }
743
744 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
745 self.zoomed = zoomed;
746 cx.notify();
747 }
748
749 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
750 if active {
751 if self.api_key.borrow().is_none() && !self.has_read_credentials {
752 self.has_read_credentials = true;
753 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
754 Some(api_key)
755 } else if let Some((_, api_key)) = cx
756 .platform()
757 .read_credentials(OPENAI_API_URL)
758 .log_err()
759 .flatten()
760 {
761 String::from_utf8(api_key).log_err()
762 } else {
763 None
764 };
765 if let Some(api_key) = api_key {
766 *self.api_key.borrow_mut() = Some(api_key);
767 } else if self.api_key_editor.is_none() {
768 self.api_key_editor = Some(build_api_key_editor(cx));
769 cx.notify();
770 }
771 }
772
773 if self.editors.is_empty() {
774 self.new_conversation(cx);
775 }
776 }
777 }
778
779 fn icon_path(&self, cx: &WindowContext) -> Option<&'static str> {
780 settings::get::<AssistantSettings>(cx)
781 .button
782 .then(|| "icons/ai.svg")
783 }
784
785 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
786 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
787 }
788
789 fn should_change_position_on_event(event: &Self::Event) -> bool {
790 matches!(event, AssistantPanelEvent::DockPositionChanged)
791 }
792
793 fn should_activate_on_event(_: &Self::Event) -> bool {
794 false
795 }
796
797 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
798 matches!(event, AssistantPanelEvent::Close)
799 }
800
801 fn has_focus(&self, _: &WindowContext) -> bool {
802 self.has_focus
803 }
804
805 fn is_focus_event(event: &Self::Event) -> bool {
806 matches!(event, AssistantPanelEvent::Focus)
807 }
808}
809
/// Notifications a `Conversation` model emits to its observers.
enum ConversationEvent {
    /// The conversation's buffer was edited (emitted from `handle_buffer_event`).
    MessagesEdited,
    SummaryChanged,
    StreamedCompletion,
}
815
/// Conversation summary text plus a completion flag. Summaries restored by
/// `Conversation::deserialize` are marked `done`.
#[derive(Default)]
struct Summary {
    text: String,
    // NOTE(review): presumably false while a summary is still being generated — confirm.
    done: bool,
}
821
/// Model for a single assistant conversation: all message text lives in one
/// buffer, with per-message anchors and metadata kept alongside it.
struct Conversation {
    // Buffer holding the text of every message (Markdown once the language loads).
    buffer: ModelHandle<Buffer>,
    // Start anchor into `buffer` for each message.
    message_anchors: Vec<MessageAnchor>,
    // Role, send time and status per message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Id to hand out to the next message created.
    next_message_id: MessageId,
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    // OpenAI model name used for requests and for token counting.
    model: String,
    // Token count of the current messages; `None` until the first count finishes.
    token_count: Option<usize>,
    // Context-window size of `model`.
    max_token_count: usize,
    // Latest scheduled token-count task; replaced on every recount.
    pending_token_count: Task<Option<()>>,
    // API key shared with the panel that created this conversation.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    // File this conversation was loaded from; `None` for a new conversation.
    path: Option<PathBuf>,
    // Subscription to `buffer` events.
    _subscriptions: Vec<Subscription>,
}
840
// Conversations notify observers through `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
844
845impl Conversation {
846 fn new(
847 api_key: Rc<RefCell<Option<String>>>,
848 language_registry: Arc<LanguageRegistry>,
849 cx: &mut ModelContext<Self>,
850 ) -> Self {
851 let model = "gpt-3.5-turbo-0613";
852 let markdown = language_registry.language_for_name("Markdown");
853 let buffer = cx.add_model(|cx| {
854 let mut buffer = Buffer::new(0, "", cx);
855 buffer.set_language_registry(language_registry);
856 cx.spawn_weak(|buffer, mut cx| async move {
857 let markdown = markdown.await?;
858 let buffer = buffer
859 .upgrade(&cx)
860 .ok_or_else(|| anyhow!("buffer was dropped"))?;
861 buffer.update(&mut cx, |buffer, cx| {
862 buffer.set_language(Some(markdown), cx)
863 });
864 anyhow::Ok(())
865 })
866 .detach_and_log_err(cx);
867 buffer
868 });
869
870 let mut this = Self {
871 message_anchors: Default::default(),
872 messages_metadata: Default::default(),
873 next_message_id: Default::default(),
874 summary: None,
875 pending_summary: Task::ready(None),
876 completion_count: Default::default(),
877 pending_completions: Default::default(),
878 token_count: None,
879 max_token_count: tiktoken_rs::model::get_context_size(model),
880 pending_token_count: Task::ready(None),
881 model: model.into(),
882 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
883 pending_save: Task::ready(Ok(())),
884 path: None,
885 api_key,
886 buffer,
887 };
888 let message = MessageAnchor {
889 id: MessageId(post_inc(&mut this.next_message_id.0)),
890 start: language::Anchor::MIN,
891 };
892 this.message_anchors.push(message.clone());
893 this.messages_metadata.insert(
894 message.id,
895 MessageMetadata {
896 role: Role::User,
897 sent_at: Local::now(),
898 status: MessageStatus::Done,
899 },
900 );
901
902 this.count_remaining_tokens(cx);
903 this
904 }
905
906 fn serialize(&self, cx: &AppContext) -> SavedConversation {
907 SavedConversation {
908 zed: "conversation".into(),
909 version: SavedConversation::VERSION.into(),
910 text: self.buffer.read(cx).text(),
911 message_metadata: self.messages_metadata.clone(),
912 messages: self
913 .messages(cx)
914 .map(|message| SavedMessage {
915 id: message.id,
916 start: message.offset_range.start,
917 })
918 .collect(),
919 summary: self
920 .summary
921 .as_ref()
922 .map(|summary| summary.text.clone())
923 .unwrap_or_default(),
924 model: self.model.clone(),
925 }
926 }
927
928 fn deserialize(
929 saved_conversation: SavedConversation,
930 path: PathBuf,
931 api_key: Rc<RefCell<Option<String>>>,
932 language_registry: Arc<LanguageRegistry>,
933 cx: &mut ModelContext<Self>,
934 ) -> Self {
935 let model = saved_conversation.model;
936 let markdown = language_registry.language_for_name("Markdown");
937 let mut message_anchors = Vec::new();
938 let mut next_message_id = MessageId(0);
939 let buffer = cx.add_model(|cx| {
940 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
941 for message in saved_conversation.messages {
942 message_anchors.push(MessageAnchor {
943 id: message.id,
944 start: buffer.anchor_before(message.start),
945 });
946 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
947 }
948 buffer.set_language_registry(language_registry);
949 cx.spawn_weak(|buffer, mut cx| async move {
950 let markdown = markdown.await?;
951 let buffer = buffer
952 .upgrade(&cx)
953 .ok_or_else(|| anyhow!("buffer was dropped"))?;
954 buffer.update(&mut cx, |buffer, cx| {
955 buffer.set_language(Some(markdown), cx)
956 });
957 anyhow::Ok(())
958 })
959 .detach_and_log_err(cx);
960 buffer
961 });
962
963 let mut this = Self {
964 message_anchors,
965 messages_metadata: saved_conversation.message_metadata,
966 next_message_id,
967 summary: Some(Summary {
968 text: saved_conversation.summary,
969 done: true,
970 }),
971 pending_summary: Task::ready(None),
972 completion_count: Default::default(),
973 pending_completions: Default::default(),
974 token_count: None,
975 max_token_count: tiktoken_rs::model::get_context_size(&model),
976 pending_token_count: Task::ready(None),
977 model,
978 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
979 pending_save: Task::ready(Ok(())),
980 path: Some(path),
981 api_key,
982 buffer,
983 };
984 this.count_remaining_tokens(cx);
985 this
986 }
987
988 fn handle_buffer_event(
989 &mut self,
990 _: ModelHandle<Buffer>,
991 event: &language::Event,
992 cx: &mut ModelContext<Self>,
993 ) {
994 match event {
995 language::Event::Edited => {
996 self.count_remaining_tokens(cx);
997 cx.emit(ConversationEvent::MessagesEdited);
998 }
999 _ => {}
1000 }
1001 }
1002
1003 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1004 let messages = self
1005 .messages(cx)
1006 .into_iter()
1007 .filter_map(|message| {
1008 Some(tiktoken_rs::ChatCompletionRequestMessage {
1009 role: match message.role {
1010 Role::User => "user".into(),
1011 Role::Assistant => "assistant".into(),
1012 Role::System => "system".into(),
1013 },
1014 content: self
1015 .buffer
1016 .read(cx)
1017 .text_for_range(message.offset_range)
1018 .collect(),
1019 name: None,
1020 })
1021 })
1022 .collect::<Vec<_>>();
1023 let model = self.model.clone();
1024 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1025 async move {
1026 cx.background().timer(Duration::from_millis(200)).await;
1027 let token_count = cx
1028 .background()
1029 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1030 .await?;
1031
1032 this.upgrade(&cx)
1033 .ok_or_else(|| anyhow!("conversation was dropped"))?
1034 .update(&mut cx, |this, cx| {
1035 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1036 this.token_count = Some(token_count);
1037 cx.notify()
1038 });
1039 anyhow::Ok(())
1040 }
1041 .log_err()
1042 });
1043 }
1044
1045 fn remaining_tokens(&self) -> Option<isize> {
1046 Some(self.max_token_count as isize - self.token_count? as isize)
1047 }
1048
    /// Switches the conversation to `model`, schedules a token recount and
    /// notifies observers.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        // Assign first: `count_remaining_tokens` clones `self.model` for its
        // background count.
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
1054
    /// Requests an assistant reply for each message in `selected_messages`.
    ///
    /// For every selected id that still has metadata:
    /// - If the message is an assistant message, no request is made; an empty
    ///   `Done` user message is inserted after it instead.
    /// - Otherwise a streaming request is built from every `Done` message. A
    ///   system message is placed right after the selected message asking the
    ///   model to treat the messages that follow as background knowledge. A final
    ///   system message tells the model which message id to reply to.
    /// - For such a request, a `Pending` assistant message is inserted after the
    ///   selected message and streamed text is appended to it.
    /// - If the selected message was the last valid message, an empty user message
    ///   is also queued after the reply.
    /// - Non-assistant messages are skipped when no API key is set.
    ///
    /// All streaming tasks started by one call are stored together as a single
    /// `PendingCompletion`, so `cancel_last_assist` can drop them as a group.
    ///
    /// Returns the user messages inserted by this call, so the caller can move the
    /// cursors into them.
    fn assist(
        &mut self,
        selected_messages: HashSet<MessageId>,
        cx: &mut ModelContext<Self>,
    ) -> Vec<MessageAnchor> {
        let mut user_messages = Vec::new();
        let mut tasks = Vec::new();

        // Id of the last message whose start anchor is still valid; used below to
        // decide whether to queue a follow-up user message.
        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
            message
                .start
                .is_valid(self.buffer.read(cx))
                .then_some(message.id)
        });

        for selected_message_id in selected_messages {
            let selected_message_role =
                if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
                    metadata.role
                } else {
                    continue;
                };

            if selected_message_role == Role::Assistant {
                // Assisting on an assistant message just opens a new user message after it.
                if let Some(user_message) = self.insert_message_after(
                    selected_message_id,
                    Role::User,
                    MessageStatus::Done,
                    cx,
                ) {
                    user_messages.push(user_message);
                } else {
                    continue;
                }
            } else {
                let request = OpenAIRequest {
                    model: self.model.clone(),
                    messages: self
                        .messages(cx)
                        .filter(|message| matches!(message.status, MessageStatus::Done))
                        .flat_map(|message| {
                            let mut system_message = None;
                            if message.id == selected_message_id {
                                system_message = Some(RequestMessage {
                                    role: Role::System,
                                    content: concat!(
                                        "Treat the following messages as additional knowledge you have learned about, ",
                                        "but act as if they were not part of this conversation. That is, treat them ",
                                        "as if the user didn't see them and couldn't possibly inquire about them."
                                    ).into()
                                });
                            }

                            Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
                        })
                        .chain(Some(RequestMessage {
                            role: Role::System,
                            content: format!(
                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
                                selected_message_id.0
                            ),
                        }))
                        .collect(),
                    stream: true,
                };

                let Some(api_key) = self.api_key.borrow().clone() else { continue };
                let stream = stream_completion(api_key, cx.background().clone(), request);
                // NOTE(review): relies on every id with metadata also having an entry in
                // `message_anchors`; otherwise this unwrap panics.
                let assistant_message = self
                    .insert_message_after(
                        selected_message_id,
                        Role::Assistant,
                        MessageStatus::Pending,
                        cx,
                    )
                    .unwrap();

                // Queue up the user's next reply
                if Some(selected_message_id) == last_message_id {
                    let user_message = self
                        .insert_message_after(
                            assistant_message.id,
                            Role::User,
                            MessageStatus::Done,
                            cx,
                        )
                        .unwrap();
                    user_messages.push(user_message);
                }

                tasks.push(cx.spawn_weak({
                    |this, mut cx| async move {
                        let assistant_message_id = assistant_message.id;
                        let stream_completion = async {
                            let mut messages = stream.await?;

                            while let Some(message) = messages.next().await {
                                let mut message = message?;
                                if let Some(choice) = message.choices.pop() {
                                    this.upgrade(&cx)
                                        .ok_or_else(|| anyhow!("conversation was dropped"))?
                                        .update(&mut cx, |this, cx| {
                                            let text: Arc<str> = choice.delta.content?.into();
                                            let message_ix = this.message_anchors.iter().position(
                                                |message| message.id == assistant_message_id,
                                            )?;
                                            this.buffer.update(cx, |buffer, cx| {
                                                // Append at the end of the assistant message: just
                                                // before the character preceding the next valid
                                                // message, or at the end of the buffer.
                                                let offset = this.message_anchors[message_ix + 1..]
                                                    .iter()
                                                    .find(|message| message.start.is_valid(buffer))
                                                    .map_or(buffer.len(), |message| {
                                                        message
                                                            .start
                                                            .to_offset(buffer)
                                                            .saturating_sub(1)
                                                    });
                                                buffer.edit([(offset..offset, text)], None, cx);
                                            });
                                            cx.emit(ConversationEvent::StreamedCompletion);

                                            Some(())
                                        });
                                }
                                smol::future::yield_now().await;
                            }

                            this.upgrade(&cx)
                                .ok_or_else(|| anyhow!("conversation was dropped"))?
                                .update(&mut cx, |this, cx| {
                                    // NOTE(review): by the time this runs, `completion_count`
                                    // has already been post-incremented past this group's id,
                                    // so this comparison appears never to match. Finished
                                    // groups would then stay in `pending_completions`, and
                                    // `cancel_last_assist` would pop them. Capturing this
                                    // group's id instead would also drop sibling tasks still
                                    // streaming for other selected messages. TODO confirm the
                                    // intended behavior.
                                    this.pending_completions.retain(|completion| {
                                        completion.id != this.completion_count
                                    });
                                    this.summarize(cx);
                                });

                            anyhow::Ok(())
                        };

                        let result = stream_completion.await;
                        // Record the final status of the assistant message: Done, or
                        // Error carrying the trimmed error text.
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |this, cx| {
                                if let Some(metadata) =
                                    this.messages_metadata.get_mut(&assistant_message.id)
                                {
                                    match result {
                                        Ok(_) => {
                                            metadata.status = MessageStatus::Done;
                                        }
                                        Err(error) => {
                                            metadata.status = MessageStatus::Error(
                                                error.to_string().trim().into(),
                                            );
                                        }
                                    }
                                    cx.notify();
                                }
                            });
                        }
                    }
                }));
            }
        }

        if !tasks.is_empty() {
            self.pending_completions.push(PendingCompletion {
                id: post_inc(&mut self.completion_count),
                _tasks: tasks,
            });
        }

        user_messages
    }
1227
1228 fn cancel_last_assist(&mut self) -> bool {
1229 self.pending_completions.pop().is_some()
1230 }
1231
1232 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1233 for id in ids {
1234 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1235 metadata.role.cycle();
1236 cx.emit(ConversationEvent::MessagesEdited);
1237 cx.notify();
1238 }
1239 }
1240 }
1241
1242 fn insert_message_after(
1243 &mut self,
1244 message_id: MessageId,
1245 role: Role,
1246 status: MessageStatus,
1247 cx: &mut ModelContext<Self>,
1248 ) -> Option<MessageAnchor> {
1249 if let Some(prev_message_ix) = self
1250 .message_anchors
1251 .iter()
1252 .position(|message| message.id == message_id)
1253 {
1254 // Find the next valid message after the one we were given.
1255 let mut next_message_ix = prev_message_ix + 1;
1256 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1257 if next_message.start.is_valid(self.buffer.read(cx)) {
1258 break;
1259 }
1260 next_message_ix += 1;
1261 }
1262
1263 let start = self.buffer.update(cx, |buffer, cx| {
1264 let offset = self
1265 .message_anchors
1266 .get(next_message_ix)
1267 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1268 buffer.edit([(offset..offset, "\n")], None, cx);
1269 buffer.anchor_before(offset + 1)
1270 });
1271 let message = MessageAnchor {
1272 id: MessageId(post_inc(&mut self.next_message_id.0)),
1273 start,
1274 };
1275 self.message_anchors
1276 .insert(next_message_ix, message.clone());
1277 self.messages_metadata.insert(
1278 message.id,
1279 MessageMetadata {
1280 role,
1281 sent_at: Local::now(),
1282 status,
1283 },
1284 );
1285 cx.emit(ConversationEvent::MessagesEdited);
1286 Some(message)
1287 } else {
1288 None
1289 }
1290 }
1291
1292 fn split_message(
1293 &mut self,
1294 range: Range<usize>,
1295 cx: &mut ModelContext<Self>,
1296 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1297 let start_message = self.message_for_offset(range.start, cx);
1298 let end_message = self.message_for_offset(range.end, cx);
1299 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1300 // Prevent splitting when range spans multiple messages.
1301 if start_message.id != end_message.id {
1302 return (None, None);
1303 }
1304
1305 let message = start_message;
1306 let role = message.role;
1307 let mut edited_buffer = false;
1308
1309 let mut suffix_start = None;
1310 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1311 {
1312 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1313 suffix_start = Some(range.end + 1);
1314 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1315 suffix_start = Some(range.end);
1316 }
1317 }
1318
1319 let suffix = if let Some(suffix_start) = suffix_start {
1320 MessageAnchor {
1321 id: MessageId(post_inc(&mut self.next_message_id.0)),
1322 start: self.buffer.read(cx).anchor_before(suffix_start),
1323 }
1324 } else {
1325 self.buffer.update(cx, |buffer, cx| {
1326 buffer.edit([(range.end..range.end, "\n")], None, cx);
1327 });
1328 edited_buffer = true;
1329 MessageAnchor {
1330 id: MessageId(post_inc(&mut self.next_message_id.0)),
1331 start: self.buffer.read(cx).anchor_before(range.end + 1),
1332 }
1333 };
1334
1335 self.message_anchors
1336 .insert(message.index_range.end + 1, suffix.clone());
1337 self.messages_metadata.insert(
1338 suffix.id,
1339 MessageMetadata {
1340 role,
1341 sent_at: Local::now(),
1342 status: MessageStatus::Done,
1343 },
1344 );
1345
1346 let new_messages =
1347 if range.start == range.end || range.start == message.offset_range.start {
1348 (None, Some(suffix))
1349 } else {
1350 let mut prefix_end = None;
1351 if range.start > message.offset_range.start
1352 && range.end < message.offset_range.end - 1
1353 {
1354 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1355 prefix_end = Some(range.start + 1);
1356 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1357 == Some('\n')
1358 {
1359 prefix_end = Some(range.start);
1360 }
1361 }
1362
1363 let selection = if let Some(prefix_end) = prefix_end {
1364 cx.emit(ConversationEvent::MessagesEdited);
1365 MessageAnchor {
1366 id: MessageId(post_inc(&mut self.next_message_id.0)),
1367 start: self.buffer.read(cx).anchor_before(prefix_end),
1368 }
1369 } else {
1370 self.buffer.update(cx, |buffer, cx| {
1371 buffer.edit([(range.start..range.start, "\n")], None, cx)
1372 });
1373 edited_buffer = true;
1374 MessageAnchor {
1375 id: MessageId(post_inc(&mut self.next_message_id.0)),
1376 start: self.buffer.read(cx).anchor_before(range.end + 1),
1377 }
1378 };
1379
1380 self.message_anchors
1381 .insert(message.index_range.end + 1, selection.clone());
1382 self.messages_metadata.insert(
1383 selection.id,
1384 MessageMetadata {
1385 role,
1386 sent_at: Local::now(),
1387 status: MessageStatus::Done,
1388 },
1389 );
1390 (Some(selection), Some(suffix))
1391 };
1392
1393 if !edited_buffer {
1394 cx.emit(ConversationEvent::MessagesEdited);
1395 }
1396 new_messages
1397 } else {
1398 (None, None)
1399 }
1400 }
1401
1402 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1403 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1404 let api_key = self.api_key.borrow().clone();
1405 if let Some(api_key) = api_key {
1406 let messages = self
1407 .messages(cx)
1408 .take(2)
1409 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1410 .chain(Some(RequestMessage {
1411 role: Role::User,
1412 content:
1413 "Summarize the conversation into a short title without punctuation"
1414 .into(),
1415 }));
1416 let request = OpenAIRequest {
1417 model: self.model.clone(),
1418 messages: messages.collect(),
1419 stream: true,
1420 };
1421
1422 let stream = stream_completion(api_key, cx.background().clone(), request);
1423 self.pending_summary = cx.spawn(|this, mut cx| {
1424 async move {
1425 let mut messages = stream.await?;
1426
1427 while let Some(message) = messages.next().await {
1428 let mut message = message?;
1429 if let Some(choice) = message.choices.pop() {
1430 let text = choice.delta.content.unwrap_or_default();
1431 this.update(&mut cx, |this, cx| {
1432 this.summary
1433 .get_or_insert(Default::default())
1434 .text
1435 .push_str(&text);
1436 cx.emit(ConversationEvent::SummaryChanged);
1437 });
1438 }
1439 }
1440
1441 this.update(&mut cx, |this, cx| {
1442 if let Some(summary) = this.summary.as_mut() {
1443 summary.done = true;
1444 cx.emit(ConversationEvent::SummaryChanged);
1445 }
1446 });
1447
1448 anyhow::Ok(())
1449 }
1450 .log_err()
1451 });
1452 }
1453 }
1454 }
1455
1456 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1457 self.messages_for_offsets([offset], cx).pop()
1458 }
1459
    /// Maps `offsets` to the messages that contain them.
    ///
    /// Each message appears at most once: consecutive offsets falling in the same
    /// message yield a single entry. An offset past the end of every message's range
    /// (such as the very end of the buffer) resolves to the last message.
    ///
    /// The scan only moves forward through the messages, so this assumes `offsets`
    /// is in ascending order. An out-of-order offset would be attributed to the last
    /// message. TODO confirm all callers guarantee the ordering.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // If none does, stop at the last message.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message.
            // Once we are on the last message, every remaining offset maps to it.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1490
    /// Iterates over the conversation's messages in `message_anchors` order,
    /// resolved against the current buffer contents.
    ///
    /// How messages are formed:
    /// - Each yielded message starts at its anchor and ends where the next *valid*
    ///   anchor starts; the last message extends to the end of the buffer.
    /// - Anchors after it that are no longer valid (their text was deleted) are
    ///   folded into it.
    /// - `index_range.end` is the index of the last anchor folded in, so it is
    ///   inclusive despite being a `Range`.
    ///
    /// Caveats:
    /// - Only anchors following a message are validity-checked. The first anchor is
    ///   presumably always valid — TODO confirm.
    /// - Iteration stops entirely at the first anchor without metadata.
    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            // The body always returns on its first pass; the `while let` acts as an `if let`.
            while let Some((start_ix, message_anchor)) = message_anchors.next() {
                // `?` returns `None` from the closure, ending the whole iteration.
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Absorb following invalid anchors; stop at the next valid one.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);
                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    sent_at: metadata.sent_at,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
1525
1526 fn save(
1527 &mut self,
1528 debounce: Option<Duration>,
1529 fs: Arc<dyn Fs>,
1530 cx: &mut ModelContext<Conversation>,
1531 ) {
1532 self.pending_save = cx.spawn(|this, mut cx| async move {
1533 if let Some(debounce) = debounce {
1534 cx.background().timer(debounce).await;
1535 }
1536
1537 let (old_path, summary) = this.read_with(&cx, |this, _| {
1538 let path = this.path.clone();
1539 let summary = if let Some(summary) = this.summary.as_ref() {
1540 if summary.done {
1541 Some(summary.text.clone())
1542 } else {
1543 None
1544 }
1545 } else {
1546 None
1547 };
1548 (path, summary)
1549 });
1550
1551 if let Some(summary) = summary {
1552 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1553 let path = if let Some(old_path) = old_path {
1554 old_path
1555 } else {
1556 let mut discriminant = 1;
1557 let mut new_path;
1558 loop {
1559 new_path = CONVERSATIONS_DIR.join(&format!(
1560 "{} - {}.zed.json",
1561 summary.trim(),
1562 discriminant
1563 ));
1564 if fs.is_file(&new_path).await {
1565 discriminant += 1;
1566 } else {
1567 break;
1568 }
1569 }
1570 new_path
1571 };
1572
1573 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1574 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1575 .await?;
1576 this.update(&mut cx, |this, _| this.path = Some(path));
1577 }
1578
1579 Ok(())
1580 });
1581 }
1582}
1583
/// The streaming tasks started by a single `Conversation::assist` call.
///
/// `cancel_last_assist` pops the most recent group, dropping its tasks; dropping a
/// task presumably cancels its stream.
struct PendingCompletion {
    /// Value of `completion_count` when this group was created.
    id: usize,
    /// One task per message that a completion was requested for.
    _tasks: Vec<Task<()>>,
}
1588
/// Events emitted by a [`ConversationEditor`].
enum ConversationEditorEvent {
    /// The conversation summary, which is used as the tab title, changed.
    TabContentChanged,
}
1592
/// The newest cursor's position relative to the viewport.
///
/// Captured so the view can be kept stable while a completion streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: horizontal scroll position. `y`: number of display rows between the top
    /// of the viewport and the cursor's row.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection.
    cursor: Anchor,
}
1598
/// A view that edits a single [`Conversation`] through an embedded [`Editor`].
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem used when saving the conversation.
    fs: Arc<dyn Fs>,
    /// Editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor position to restore while completions stream in; `None` when the view
    /// should not be pinned to the cursor.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
1607
1608impl ConversationEditor {
1609 fn new(
1610 api_key: Rc<RefCell<Option<String>>>,
1611 language_registry: Arc<LanguageRegistry>,
1612 fs: Arc<dyn Fs>,
1613 cx: &mut ViewContext<Self>,
1614 ) -> Self {
1615 let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
1616 Self::for_conversation(conversation, fs, cx)
1617 }
1618
1619 fn for_conversation(
1620 conversation: ModelHandle<Conversation>,
1621 fs: Arc<dyn Fs>,
1622 cx: &mut ViewContext<Self>,
1623 ) -> Self {
1624 let editor = cx.add_view(|cx| {
1625 let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
1626 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1627 editor.set_show_gutter(false, cx);
1628 editor.set_show_wrap_guides(false, cx);
1629 editor
1630 });
1631
1632 let _subscriptions = vec![
1633 cx.observe(&conversation, |_, _, cx| cx.notify()),
1634 cx.subscribe(&conversation, Self::handle_conversation_event),
1635 cx.subscribe(&editor, Self::handle_editor_event),
1636 ];
1637
1638 let mut this = Self {
1639 conversation,
1640 editor,
1641 blocks: Default::default(),
1642 scroll_position: None,
1643 fs,
1644 _subscriptions,
1645 };
1646 this.update_message_headers(cx);
1647 this
1648 }
1649
1650 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1651 let cursors = self.cursors(cx);
1652
1653 let user_messages = self.conversation.update(cx, |conversation, cx| {
1654 let selected_messages = conversation
1655 .messages_for_offsets(cursors, cx)
1656 .into_iter()
1657 .map(|message| message.id)
1658 .collect();
1659 conversation.assist(selected_messages, cx)
1660 });
1661 let new_selections = user_messages
1662 .iter()
1663 .map(|message| {
1664 let cursor = message
1665 .start
1666 .to_offset(self.conversation.read(cx).buffer.read(cx));
1667 cursor..cursor
1668 })
1669 .collect::<Vec<_>>();
1670 if !new_selections.is_empty() {
1671 self.editor.update(cx, |editor, cx| {
1672 editor.change_selections(
1673 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1674 cx,
1675 |selections| selections.select_ranges(new_selections),
1676 );
1677 });
1678 // Avoid scrolling to the new cursor position so the assistant's output is stable.
1679 cx.defer(|this, _| this.scroll_position = None);
1680 }
1681 }
1682
1683 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1684 if !self
1685 .conversation
1686 .update(cx, |conversation, _| conversation.cancel_last_assist())
1687 {
1688 cx.propagate_action();
1689 }
1690 }
1691
1692 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1693 let cursors = self.cursors(cx);
1694 self.conversation.update(cx, |conversation, cx| {
1695 let messages = conversation
1696 .messages_for_offsets(cursors, cx)
1697 .into_iter()
1698 .map(|message| message.id)
1699 .collect();
1700 conversation.cycle_message_roles(messages, cx)
1701 });
1702 }
1703
1704 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1705 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1706 selections
1707 .into_iter()
1708 .map(|selection| selection.head())
1709 .collect()
1710 }
1711
1712 fn handle_conversation_event(
1713 &mut self,
1714 _: ModelHandle<Conversation>,
1715 event: &ConversationEvent,
1716 cx: &mut ViewContext<Self>,
1717 ) {
1718 match event {
1719 ConversationEvent::MessagesEdited => {
1720 self.update_message_headers(cx);
1721 self.conversation.update(cx, |conversation, cx| {
1722 conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1723 });
1724 }
1725 ConversationEvent::SummaryChanged => {
1726 cx.emit(ConversationEditorEvent::TabContentChanged);
1727 self.conversation.update(cx, |conversation, cx| {
1728 conversation.save(None, self.fs.clone(), cx);
1729 });
1730 }
1731 ConversationEvent::StreamedCompletion => {
1732 self.editor.update(cx, |editor, cx| {
1733 if let Some(scroll_position) = self.scroll_position {
1734 let snapshot = editor.snapshot(cx);
1735 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1736 let scroll_top =
1737 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1738 editor.set_scroll_position(
1739 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1740 cx,
1741 );
1742 }
1743 });
1744 }
1745 }
1746 }
1747
1748 fn handle_editor_event(
1749 &mut self,
1750 _: ViewHandle<Editor>,
1751 event: &editor::Event,
1752 cx: &mut ViewContext<Self>,
1753 ) {
1754 match event {
1755 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1756 let cursor_scroll_position = self.cursor_scroll_position(cx);
1757 if *autoscroll {
1758 self.scroll_position = cursor_scroll_position;
1759 } else if self.scroll_position != cursor_scroll_position {
1760 self.scroll_position = None;
1761 }
1762 }
1763 editor::Event::SelectionsChanged { .. } => {
1764 self.scroll_position = self.cursor_scroll_position(cx);
1765 }
1766 _ => {}
1767 }
1768 }
1769
1770 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1771 self.editor.update(cx, |editor, cx| {
1772 let snapshot = editor.snapshot(cx);
1773 let cursor = editor.selections.newest_anchor().head();
1774 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1775 let scroll_position = editor
1776 .scroll_manager
1777 .anchor()
1778 .scroll_position(&snapshot.display_snapshot);
1779
1780 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1781 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1782 Some(ScrollPosition {
1783 cursor,
1784 offset_before_cursor: vec2f(
1785 scroll_position.x(),
1786 cursor_row - scroll_position.y(),
1787 ),
1788 })
1789 } else {
1790 None
1791 }
1792 })
1793 }
1794
    /// Rebuilds the header block shown above every message.
    ///
    /// Every block recorded in `self.blocks` is removed, and a new two-row sticky
    /// block is inserted above each message's start anchor. Each header shows:
    /// - the sender role ("You", "Assistant" or "System"); pressing the mouse on it
    ///   cycles that message's role;
    /// - the time the message was sent;
    /// - when the message status is `MessageStatus::Error`, an error icon whose
    ///   tooltip shows the error text.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // The editor was created with `Editor::for_buffer`, so its multi-buffer
            // has a single excerpt; message anchors are mapped into it.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        // `message` is moved into the render closure; it is a snapshot
                        // taken when the headers were last rebuilt.
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::new::<Sender, _>(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    // Sent time on a 12-hour clock.
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1905
1906 fn quote_selection(
1907 workspace: &mut Workspace,
1908 _: &QuoteSelection,
1909 cx: &mut ViewContext<Workspace>,
1910 ) {
1911 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1912 return;
1913 };
1914 let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
1915 return;
1916 };
1917
1918 let text = editor.read_with(cx, |editor, cx| {
1919 let range = editor.selections.newest::<usize>(cx).range();
1920 let buffer = editor.buffer().read(cx).snapshot(cx);
1921 let start_language = buffer.language_at(range.start);
1922 let end_language = buffer.language_at(range.end);
1923 let language_name = if start_language == end_language {
1924 start_language.map(|language| language.name())
1925 } else {
1926 None
1927 };
1928 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1929
1930 let selected_text = buffer.text_for_range(range).collect::<String>();
1931 if selected_text.is_empty() {
1932 None
1933 } else {
1934 Some(if language_name == "markdown" {
1935 selected_text
1936 .lines()
1937 .map(|line| format!("> {}", line))
1938 .collect::<Vec<_>>()
1939 .join("\n")
1940 } else {
1941 format!("```{language_name}\n{selected_text}\n```")
1942 })
1943 }
1944 });
1945
1946 // Activate the panel
1947 if !panel.read(cx).has_focus(cx) {
1948 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1949 }
1950
1951 if let Some(text) = text {
1952 panel.update(cx, |panel, cx| {
1953 let conversation = panel
1954 .active_editor()
1955 .cloned()
1956 .unwrap_or_else(|| panel.new_conversation(cx));
1957 conversation.update(cx, |conversation, cx| {
1958 conversation
1959 .editor
1960 .update(cx, |editor, cx| editor.insert(&text, cx))
1961 });
1962 });
1963 }
1964 }
1965
1966 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1967 let editor = self.editor.read(cx);
1968 let conversation = self.conversation.read(cx);
1969 if editor.selections.count() == 1 {
1970 let selection = editor.selections.newest::<usize>(cx);
1971 let mut copied_text = String::new();
1972 let mut spanned_messages = 0;
1973 for message in conversation.messages(cx) {
1974 if message.offset_range.start >= selection.range().end {
1975 break;
1976 } else if message.offset_range.end >= selection.range().start {
1977 let range = cmp::max(message.offset_range.start, selection.range().start)
1978 ..cmp::min(message.offset_range.end, selection.range().end);
1979 if !range.is_empty() {
1980 spanned_messages += 1;
1981 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1982 for chunk in conversation.buffer.read(cx).text_for_range(range) {
1983 copied_text.push_str(&chunk);
1984 }
1985 copied_text.push('\n');
1986 }
1987 }
1988 }
1989
1990 if spanned_messages > 1 {
1991 cx.platform()
1992 .write_to_clipboard(ClipboardItem::new(copied_text));
1993 return;
1994 }
1995 }
1996
1997 cx.propagate_action();
1998 }
1999
2000 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
2001 self.conversation.update(cx, |conversation, cx| {
2002 let selections = self.editor.read(cx).selections.disjoint_anchors();
2003 for selection in selections.into_iter() {
2004 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
2005 let range = selection
2006 .map(|endpoint| endpoint.to_offset(&buffer))
2007 .range();
2008 conversation.split_message(range, cx);
2009 }
2010 });
2011 }
2012
2013 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
2014 self.conversation.update(cx, |conversation, cx| {
2015 conversation.save(None, self.fs.clone(), cx)
2016 });
2017 }
2018
2019 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2020 self.conversation.update(cx, |conversation, cx| {
2021 let new_model = match conversation.model.as_str() {
2022 "gpt-4-0613" => "gpt-3.5-turbo-0613",
2023 _ => "gpt-4-0613",
2024 };
2025 conversation.set_model(new_model.into(), cx);
2026 });
2027 }
2028
2029 fn title(&self, cx: &AppContext) -> String {
2030 self.conversation
2031 .read(cx)
2032 .summary
2033 .as_ref()
2034 .map(|summary| summary.text.clone())
2035 .unwrap_or_else(|| "New Conversation".into())
2036 }
2037
2038 fn render_current_model(
2039 &self,
2040 style: &AssistantStyle,
2041 cx: &mut ViewContext<Self>,
2042 ) -> impl Element<Self> {
2043 enum Model {}
2044
2045 MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
2046 let style = style.model.style_for(state);
2047 Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
2048 .contained()
2049 .with_style(style.container)
2050 })
2051 .with_cursor_style(CursorStyle::PointingHand)
2052 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
2053 }
2054
2055 fn render_remaining_tokens(
2056 &self,
2057 style: &AssistantStyle,
2058 cx: &mut ViewContext<Self>,
2059 ) -> Option<impl Element<Self>> {
2060 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2061 let remaining_tokens_style = if remaining_tokens <= 0 {
2062 &style.no_remaining_tokens
2063 } else if remaining_tokens <= 500 {
2064 &style.low_remaining_tokens
2065 } else {
2066 &style.remaining_tokens
2067 };
2068 Some(
2069 Label::new(
2070 remaining_tokens.to_string(),
2071 remaining_tokens_style.text.clone(),
2072 )
2073 .contained()
2074 .with_style(remaining_tokens_style.container),
2075 )
2076 }
2077}
2078
/// Subscribers to a `ConversationEditor` receive [`ConversationEditorEvent`]s.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2082
2083impl View for ConversationEditor {
2084 fn ui_name() -> &'static str {
2085 "ConversationEditor"
2086 }
2087
2088 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2089 let theme = &theme::current(cx).assistant;
2090 Stack::new()
2091 .with_child(
2092 ChildView::new(&self.editor, cx)
2093 .contained()
2094 .with_style(theme.container),
2095 )
2096 .with_child(
2097 Flex::row()
2098 .with_child(self.render_current_model(theme, cx))
2099 .with_children(self.render_remaining_tokens(theme, cx))
2100 .aligned()
2101 .top()
2102 .right(),
2103 )
2104 .into_any()
2105 }
2106
2107 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2108 if cx.is_self_focused() {
2109 cx.focus(&self.editor);
2110 }
2111 }
2112}
2113
/// Where a message starts in the conversation buffer.
///
/// The anchor becomes invalid when the text it is attached to is deleted. The
/// message is then folded into the previous one by `Conversation::messages`.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    start: language::Anchor,
}
2119
/// A message resolved against the current buffer contents, as produced by
/// `Conversation::messages`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Offsets of the message's text in the conversation buffer.
    offset_range: Range<usize>,
    /// Indices into `message_anchors` covered by this message.
    ///
    /// `end` is the index of the last (invalidated) anchor folded into this message,
    /// so it is inclusive despite being a `Range`.
    index_range: Range<usize>,
    id: MessageId,
    /// Start anchor of the message in the conversation buffer.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2130
2131impl Message {
2132 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2133 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2134 content.extend(buffer.text_for_range(self.offset_range.clone()));
2135 RequestMessage {
2136 role: self.role,
2137 content: content.trim_end().into(),
2138 }
2139 }
2140}
2141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let conversation = new_conversation(Arc::new(LanguageRegistry::test()), cx);
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = first_message_id(&conversation, cx);
        assert_eq!(
            messages(&conversation, cx),
            [(message_1, Role::User, 0..0)]
        );

        let message_2 = insert_after(&conversation, message_1, Role::Assistant, cx);
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..1),
                (message_2, Role::Assistant, 1..1),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..2),
                (message_2, Role::Assistant, 2..3),
            ]
        );

        let message_3 = insert_after(&conversation, message_2, Role::User, cx);
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..2),
                (message_2, Role::Assistant, 2..4),
                (message_3, Role::User, 4..4),
            ]
        );

        let message_4 = insert_after(&conversation, message_2, Role::User, cx);
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..2),
                (message_2, Role::Assistant, 2..4),
                (message_4, Role::User, 4..5),
                (message_3, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        // Layout before and after the cross-boundary deletion below; both are
        // checked again once the deletion is undone and redone.
        let unmerged = [
            (message_1, Role::User, 0..2),
            (message_2, Role::Assistant, 2..4),
            (message_4, Role::User, 4..6),
            (message_3, Role::User, 6..7),
        ];
        let merged = [
            (message_1, Role::User, 0..3),
            (message_3, Role::User, 3..4),
        ];
        assert_eq!(messages(&conversation, cx), unmerged);

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(messages(&conversation, cx), merged);

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(messages(&conversation, cx), unmerged);

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(messages(&conversation, cx), merged);

        // Ensure we can still insert after a merged message.
        let message_5 = insert_after(&conversation, message_1, Role::System, cx);
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..3),
                (message_5, Role::System, 3..4),
                (message_3, Role::User, 4..5),
            ]
        );
    }

    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let conversation = new_conversation(Arc::new(LanguageRegistry::test()), cx);
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = first_message_id(&conversation, cx);
        assert_eq!(
            messages(&conversation, cx),
            [(message_1, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let message_2 = split(&conversation, 3..3, cx).1.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..4),
                (message_2, Role::User, 4..16),
            ]
        );

        let message_3 = split(&conversation, 3..3, cx).1.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..4),
                (message_3, Role::User, 4..5),
                (message_2, Role::User, 5..17),
            ]
        );

        let message_4 = split(&conversation, 9..9, cx).1.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..4),
                (message_3, Role::User, 4..5),
                (message_2, Role::User, 5..9),
                (message_4, Role::User, 9..17),
            ]
        );

        let message_5 = split(&conversation, 9..9, cx).1.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..4),
                (message_3, Role::User, 4..5),
                (message_2, Role::User, 5..9),
                (message_4, Role::User, 9..10),
                (message_5, Role::User, 10..18),
            ]
        );

        let (message_6, message_7) = split(&conversation, 14..16, cx);
        let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap());
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..4),
                (message_3, Role::User, 4..5),
                (message_2, Role::User, 5..9),
                (message_4, Role::User, 9..10),
                (message_5, Role::User, 10..14),
                (message_6, Role::User, 14..17),
                (message_7, Role::User, 17..19),
            ]
        );
    }

    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let conversation = new_conversation(Arc::new(LanguageRegistry::test()), cx);
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = first_message_id(&conversation, cx);
        assert_eq!(
            messages(&conversation, cx),
            [(message_1, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = insert_after(&conversation, message_1, Role::User, cx);
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = insert_after(&conversation, message_2, Role::User, cx);
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..4),
                (message_2, Role::User, 4..8),
                (message_3, Role::User, 8..11),
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1, message_2, message_3]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1, message_3]
        );

        let message_4 = insert_after(&conversation, message_3, Role::User, cx);
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1, Role::User, 0..4),
                (message_2, Role::User, 4..8),
                (message_3, Role::User, 8..12),
                (message_4, Role::User, 12..12),
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1, message_2, message_3, message_4]
        );
    }

    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = new_conversation(registry.clone(), cx);
        let buffer = conversation.read(cx).buffer.clone();

        let message_0 = first_message_id(&conversation, cx);
        let message_1 = insert_after(&conversation, message_0, Role::Assistant, cx);
        let message_2 = insert_after(&conversation, message_1, Role::System, cx);
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = insert_after(&conversation, message_2, Role::System, cx);
        buffer.update(cx, |buffer, cx| buffer.undo(cx));

        // The same layout must survive a serialize/deserialize round trip.
        let expected = [
            (message_0, Role::User, 0..2),
            (message_1, Role::Assistant, 2..6),
            (message_2, Role::System, 6..6),
        ];
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(messages(&conversation, cx), expected);

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(messages(&deserialized_conversation, cx), expected);
    }

    /// Creates a fresh conversation model backed by `registry`.
    fn new_conversation(
        registry: Arc<LanguageRegistry>,
        cx: &mut AppContext,
    ) -> ModelHandle<Conversation> {
        cx.add_model(|cx| Conversation::new(Default::default(), registry, cx))
    }

    /// Id of the conversation's first message anchor.
    fn first_message_id(conversation: &ModelHandle<Conversation>, cx: &AppContext) -> MessageId {
        conversation.read(cx).message_anchors[0].id
    }

    /// Inserts a `Done` message with `role` after `after` and returns the new message's id.
    /// Panics if the insertion is rejected.
    fn insert_after(
        conversation: &ModelHandle<Conversation>,
        after: MessageId,
        role: Role,
        cx: &mut AppContext,
    ) -> MessageId {
        conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(after, role, MessageStatus::Done, cx)
                .unwrap()
                .id
        })
    }

    /// Splits the message at `range` and returns the ids of the resulting messages, if any.
    fn split(
        conversation: &ModelHandle<Conversation>,
        range: Range<usize>,
        cx: &mut AppContext,
    ) -> (Option<MessageId>, Option<MessageId>) {
        conversation.update(cx, |conversation, cx| {
            let (before, after) = conversation.split_message(range, cx);
            (before.map(|message| message.id), after.map(|message| message.id))
        })
    }

    /// Ids of the messages that contain the given buffer offsets.
    fn message_ids_for_offsets(
        conversation: &ModelHandle<Conversation>,
        offsets: &[usize],
        cx: &AppContext,
    ) -> Vec<MessageId> {
        let conversation = conversation.read(cx);
        conversation
            .messages_for_offsets(offsets.iter().copied(), cx)
            .into_iter()
            .map(|message| message.id)
            .collect()
    }

    /// Snapshot of every message as `(id, role, offset_range)`, in buffer order.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        let conversation = conversation.read(cx);
        conversation
            .messages(cx)
            .map(|Message { id, role, offset_range, .. }| (id, role, offset_range))
            .collect()
    }
}