1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use serde::Deserialize;
28use settings::SettingsStore;
29use std::{
30 borrow::Cow,
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::ui::IconStyle;
42use util::{
43 channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, truncate_and_trailoff, ResultExt,
44 TryFutureExt,
45};
46use workspace::{
47 dock::{DockPosition, Panel},
48 item::Item,
49 Save, Workspace,
50};
51
/// Base URL of the OpenAI REST API.
///
/// Also used as the key under which the API key is stored in (and read from / deleted
/// from) the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
53
// Actions in the `assistant` namespace. On the stable release channel `init` hides
// this namespace from the command palette.
actions!(
    assistant,
    [
        NewContext,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
66
67pub fn init(cx: &mut AppContext) {
68 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
69 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
70 filter.filtered_namespaces.insert("assistant");
71 });
72 }
73
74 settings::register::<AssistantSettings>(cx);
75 cx.add_action(
76 |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
77 if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
78 this.update(cx, |this, cx| {
79 this.new_conversation(cx);
80 })
81 }
82
83 workspace.focus_panel::<AssistantPanel>(cx);
84 },
85 );
86 cx.add_action(ConversationEditor::assist);
87 cx.capture_action(ConversationEditor::cancel_last_assist);
88 cx.capture_action(ConversationEditor::save);
89 cx.add_action(ConversationEditor::quote_selection);
90 cx.capture_action(ConversationEditor::copy);
91 cx.capture_action(ConversationEditor::split);
92 cx.capture_action(ConversationEditor::cycle_message_role);
93 cx.add_action(AssistantPanel::save_api_key);
94 cx.add_action(AssistantPanel::reset_api_key);
95 cx.add_action(AssistantPanel::toggle_zoom);
96 cx.add_action(
97 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
98 workspace.toggle_panel_focus::<AssistantPanel>(cx);
99 },
100 );
101}
102
/// Events emitted by [`AssistantPanel`] and interpreted by the dock through the
/// `Panel` trait's `should_*_on_event` hooks.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Request to zoom the panel in.
    ZoomIn,
    /// Request to zoom the panel out.
    ZoomOut,
    /// The panel wants focus.
    Focus,
    /// The panel wants to be closed.
    Close,
    /// The dock position in the settings changed.
    DockPositionChanged,
}
111
/// Dockable panel hosting the assistant's conversations, the saved-conversation
/// history, and the OpenAI API key prompt.
pub struct AssistantPanel {
    /// Width override used when docked left/right; `None` falls back to settings.
    width: Option<f32>,
    /// Height override used when docked at the bottom; `None` falls back to settings.
    height: Option<f32>,
    /// Index into `conversation_editors` of the visible conversation; `None` shows the
    /// saved-conversation list instead.
    active_conversation_index: Option<usize>,
    /// Open conversation editors, in the order they were added.
    conversation_editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of conversations saved on disk, refreshed by a directory watcher.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// Scroll state of the saved-conversation list.
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    has_focus: bool,
    /// API key shared (via `Rc`) with every conversation this panel creates or loads.
    api_key: Rc<RefCell<Option<String>>>,
    /// Present while the user is being prompted to paste an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once the environment / credential store has been consulted for a key.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
    /// Background task re-listing saved conversations when the directory changes.
    _watch_saved_conversations: Task<Result<()>>,
}
129
130impl AssistantPanel {
131 pub fn load(
132 workspace: WeakViewHandle<Workspace>,
133 cx: AsyncAppContext,
134 ) -> Task<Result<ViewHandle<Self>>> {
135 cx.spawn(|mut cx| async move {
136 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
137 let saved_conversations = SavedConversationMetadata::list(fs.clone())
138 .await
139 .log_err()
140 .unwrap_or_default();
141
142 // TODO: deserialize state.
143 workspace.update(&mut cx, |workspace, cx| {
144 cx.add_view::<Self, _>(|cx| {
145 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
146 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
147 let mut events = fs
148 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
149 .await;
150 while events.next().await.is_some() {
151 let saved_conversations = SavedConversationMetadata::list(fs.clone())
152 .await
153 .log_err()
154 .unwrap_or_default();
155 this.update(&mut cx, |this, _| {
156 this.saved_conversations = saved_conversations
157 })
158 .ok();
159 }
160
161 anyhow::Ok(())
162 });
163
164 let mut this = Self {
165 active_conversation_index: Default::default(),
166 conversation_editors: Default::default(),
167 saved_conversations,
168 saved_conversations_list_state: Default::default(),
169 zoomed: false,
170 has_focus: false,
171 api_key: Rc::new(RefCell::new(None)),
172 api_key_editor: None,
173 has_read_credentials: false,
174 languages: workspace.app_state().languages.clone(),
175 fs: workspace.app_state().fs.clone(),
176 width: None,
177 height: None,
178 subscriptions: Default::default(),
179 _watch_saved_conversations,
180 };
181
182 let mut old_dock_position = this.position(cx);
183 this.subscriptions =
184 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
185 let new_dock_position = this.position(cx);
186 if new_dock_position != old_dock_position {
187 old_dock_position = new_dock_position;
188 cx.emit(AssistantPanelEvent::DockPositionChanged);
189 }
190 })];
191
192 this
193 })
194 })
195 })
196 }
197
198 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
199 let editor = cx.add_view(|cx| {
200 ConversationEditor::new(
201 self.api_key.clone(),
202 self.languages.clone(),
203 self.fs.clone(),
204 cx,
205 )
206 });
207 self.add_conversation(editor.clone(), cx);
208 editor
209 }
210
211 fn add_conversation(
212 &mut self,
213 editor: ViewHandle<ConversationEditor>,
214 cx: &mut ViewContext<Self>,
215 ) {
216 self.subscriptions
217 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
218
219 self.active_conversation_index = Some(self.conversation_editors.len());
220 self.conversation_editors.push(editor.clone());
221 if self.has_focus(cx) {
222 cx.focus(&editor);
223 }
224 cx.notify();
225 }
226
227 fn handle_conversation_editor_event(
228 &mut self,
229 _: ViewHandle<ConversationEditor>,
230 event: &ConversationEditorEvent,
231 cx: &mut ViewContext<Self>,
232 ) {
233 match event {
234 ConversationEditorEvent::TabContentChanged => cx.notify(),
235 }
236 }
237
238 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
239 if let Some(api_key) = self
240 .api_key_editor
241 .as_ref()
242 .map(|editor| editor.read(cx).text(cx))
243 {
244 if !api_key.is_empty() {
245 cx.platform()
246 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
247 .log_err();
248 *self.api_key.borrow_mut() = Some(api_key);
249 self.api_key_editor.take();
250 cx.focus_self();
251 cx.notify();
252 }
253 } else {
254 cx.propagate_action();
255 }
256 }
257
258 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
259 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
260 self.api_key.take();
261 self.api_key_editor = Some(build_api_key_editor(cx));
262 cx.focus_self();
263 cx.notify();
264 }
265
266 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
267 if self.zoomed {
268 cx.emit(AssistantPanelEvent::ZoomOut)
269 } else {
270 cx.emit(AssistantPanelEvent::ZoomIn)
271 }
272 }
273
274 fn active_conversation_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
275 self.conversation_editors
276 .get(self.active_conversation_index?)
277 }
278
279 fn render_hamburger_button(style: &IconStyle) -> impl Element<Self> {
280 enum ListConversations {}
281 Svg::for_style(style.icon.clone())
282 .contained()
283 .with_style(style.container)
284 .mouse::<ListConversations>(0)
285 .with_cursor_style(CursorStyle::PointingHand)
286 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
287 this.active_conversation_index = None;
288 cx.notify();
289 })
290 }
291
292 fn render_plus_button(style: &IconStyle) -> impl Element<Self> {
293 enum AddConversation {}
294 Svg::for_style(style.icon.clone())
295 .contained()
296 .with_style(style.container)
297 .mouse::<AddConversation>(0)
298 .with_cursor_style(CursorStyle::PointingHand)
299 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
300 this.new_conversation(cx);
301 })
302 }
303
304 fn render_saved_conversation(
305 &mut self,
306 index: usize,
307 cx: &mut ViewContext<Self>,
308 ) -> impl Element<Self> {
309 let conversation = &self.saved_conversations[index];
310 let path = conversation.path.clone();
311 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
312 let style = &theme::current(cx).assistant.saved_conversation;
313 Flex::row()
314 .with_child(
315 Label::new(
316 conversation.mtime.format("%F %I:%M%p").to_string(),
317 style.saved_at.text.clone(),
318 )
319 .aligned()
320 .contained()
321 .with_style(style.saved_at.container),
322 )
323 .with_child(
324 Label::new(conversation.title.clone(), style.title.text.clone())
325 .aligned()
326 .contained()
327 .with_style(style.title.container),
328 )
329 .contained()
330 .with_style(*style.container.style_for(state, false))
331 })
332 .with_cursor_style(CursorStyle::PointingHand)
333 .on_click(MouseButton::Left, move |_, this, cx| {
334 this.open_conversation(path.clone(), cx)
335 .detach_and_log_err(cx)
336 })
337 }
338
339 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
340 if let Some(ix) = self.conversation_editor_index_for_path(&path, cx) {
341 self.active_conversation_index = Some(ix);
342 cx.notify();
343 return Task::ready(Ok(()));
344 }
345
346 let fs = self.fs.clone();
347 let conversation = Conversation::load(
348 path.clone(),
349 self.api_key.clone(),
350 self.languages.clone(),
351 self.fs.clone(),
352 cx,
353 );
354 cx.spawn(|this, mut cx| async move {
355 let conversation = conversation.await?;
356 this.update(&mut cx, |this, cx| {
357 // If, by the time we've loaded the conversation, the user has already opened
358 // the same conversation, we don't want to open it again.
359 if let Some(ix) = this.conversation_editor_index_for_path(&path, cx) {
360 this.active_conversation_index = Some(ix);
361 } else {
362 let editor = cx
363 .add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx));
364 this.add_conversation(editor, cx);
365 }
366 })?;
367 Ok(())
368 })
369 }
370
371 fn conversation_editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
372 self.conversation_editors
373 .iter()
374 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
375 }
376}
377
378fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
379 cx.add_view(|cx| {
380 let mut editor = Editor::single_line(
381 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
382 cx,
383 );
384 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
385 editor
386 })
387}
388
// The panel emits `AssistantPanelEvent`s, which the `Panel` impl maps to dock behavior.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
392
393impl View for AssistantPanel {
394 fn ui_name() -> &'static str {
395 "AssistantPanel"
396 }
397
398 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
399 let theme = &theme::current(cx);
400 let style = &theme.assistant;
401 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
402 Flex::column()
403 .with_child(
404 Text::new(
405 "Paste your OpenAI API key and press Enter to use the assistant",
406 style.api_key_prompt.text.clone(),
407 )
408 .aligned(),
409 )
410 .with_child(
411 ChildView::new(api_key_editor, cx)
412 .contained()
413 .with_style(style.api_key_editor.container)
414 .aligned(),
415 )
416 .contained()
417 .with_style(style.api_key_prompt.container)
418 .aligned()
419 .into_any()
420 } else {
421 let title = self.active_conversation_editor().map(|editor| {
422 Label::new(editor.read(cx).title(cx), style.title.text.clone())
423 .contained()
424 .with_style(style.title.container)
425 .aligned()
426 });
427
428 Flex::column()
429 .with_child(
430 Flex::row()
431 .with_child(
432 Self::render_hamburger_button(&style.hamburger_button).aligned(),
433 )
434 .with_children(title)
435 .with_child(
436 Self::render_plus_button(&style.plus_button)
437 .aligned()
438 .flex_float(),
439 )
440 .contained()
441 .with_style(theme.workspace.tab_bar.container)
442 .expanded()
443 .constrained()
444 .with_height(theme.workspace.tab_bar.height),
445 )
446 .with_child(if let Some(editor) = self.active_conversation_editor() {
447 ChildView::new(editor, cx).flex(1., true).into_any()
448 } else {
449 UniformList::new(
450 self.saved_conversations_list_state.clone(),
451 self.saved_conversations.len(),
452 cx,
453 |this, range, items, cx| {
454 for ix in range {
455 items.push(this.render_saved_conversation(ix, cx).into_any());
456 }
457 },
458 )
459 .flex(1., true)
460 .into_any()
461 })
462 .into_any()
463 }
464 }
465
466 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
467 self.has_focus = true;
468 if cx.is_self_focused() {
469 if let Some(editor) = self.active_conversation_editor() {
470 cx.focus(editor);
471 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
472 cx.focus(api_key_editor);
473 }
474 }
475 }
476
477 fn focus_out(&mut self, _: gpui::AnyViewHandle, _: &mut ViewContext<Self>) {
478 self.has_focus = false;
479 }
480}
481
482impl Panel for AssistantPanel {
483 fn position(&self, cx: &WindowContext) -> DockPosition {
484 match settings::get::<AssistantSettings>(cx).dock {
485 AssistantDockPosition::Left => DockPosition::Left,
486 AssistantDockPosition::Bottom => DockPosition::Bottom,
487 AssistantDockPosition::Right => DockPosition::Right,
488 }
489 }
490
491 fn position_is_valid(&self, _: DockPosition) -> bool {
492 true
493 }
494
495 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
496 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
497 let dock = match position {
498 DockPosition::Left => AssistantDockPosition::Left,
499 DockPosition::Bottom => AssistantDockPosition::Bottom,
500 DockPosition::Right => AssistantDockPosition::Right,
501 };
502 settings.dock = Some(dock);
503 });
504 }
505
506 fn size(&self, cx: &WindowContext) -> f32 {
507 let settings = settings::get::<AssistantSettings>(cx);
508 match self.position(cx) {
509 DockPosition::Left | DockPosition::Right => {
510 self.width.unwrap_or_else(|| settings.default_width)
511 }
512 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
513 }
514 }
515
516 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
517 match self.position(cx) {
518 DockPosition::Left | DockPosition::Right => self.width = Some(size),
519 DockPosition::Bottom => self.height = Some(size),
520 }
521 cx.notify();
522 }
523
524 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
525 matches!(event, AssistantPanelEvent::ZoomIn)
526 }
527
528 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
529 matches!(event, AssistantPanelEvent::ZoomOut)
530 }
531
532 fn is_zoomed(&self, _: &WindowContext) -> bool {
533 self.zoomed
534 }
535
536 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
537 self.zoomed = zoomed;
538 cx.notify();
539 }
540
541 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
542 if active {
543 if self.api_key.borrow().is_none() && !self.has_read_credentials {
544 self.has_read_credentials = true;
545 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
546 Some(api_key)
547 } else if let Some((_, api_key)) = cx
548 .platform()
549 .read_credentials(OPENAI_API_URL)
550 .log_err()
551 .flatten()
552 {
553 String::from_utf8(api_key).log_err()
554 } else {
555 None
556 };
557 if let Some(api_key) = api_key {
558 *self.api_key.borrow_mut() = Some(api_key);
559 } else if self.api_key_editor.is_none() {
560 self.api_key_editor = Some(build_api_key_editor(cx));
561 cx.notify();
562 }
563 }
564
565 if self.conversation_editors.is_empty() {
566 self.new_conversation(cx);
567 }
568 }
569 }
570
571 fn icon_path(&self) -> &'static str {
572 "icons/robot_14.svg"
573 }
574
575 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
576 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
577 }
578
579 fn should_change_position_on_event(event: &Self::Event) -> bool {
580 matches!(event, AssistantPanelEvent::DockPositionChanged)
581 }
582
583 fn should_activate_on_event(_: &Self::Event) -> bool {
584 false
585 }
586
587 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
588 matches!(event, AssistantPanelEvent::Close)
589 }
590
591 fn has_focus(&self, _: &WindowContext) -> bool {
592 self.has_focus
593 }
594
595 fn is_focus_event(event: &Self::Event) -> bool {
596 matches!(event, AssistantPanelEvent::Focus)
597 }
598}
599
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// Messages were added, split, re-roled, or their text was edited.
    MessagesEdited,
    /// The conversation summary (title) changed.
    SummaryChanged,
    /// Streamed completion text was appended to an assistant message.
    StreamedCompletion,
}
605
/// Short title generated for a conversation.
#[derive(Default)]
struct Summary {
    /// Title text accumulated so far.
    text: String,
    /// Whether generation of the title has finished.
    done: bool,
}
611
/// A chat conversation: a single text buffer partitioned into messages by anchors,
/// plus per-message metadata and the background tasks that talk to OpenAI.
struct Conversation {
    /// Buffer holding the text of every message.
    buffer: ModelHandle<Buffer>,
    /// Start anchor of each message, in buffer order.
    message_anchors: Vec<MessageAnchor>,
    /// Role, timestamp and status of each message, keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id to assign to the next message created.
    next_message_id: MessageId,
    /// Generated title, if one has been requested or loaded.
    summary: Option<Summary>,
    /// Task streaming the title; replaced when a new summary is requested.
    pending_summary: Task<Option<()>>,
    /// Counter used to assign ids to `pending_completions` entries.
    completion_count: usize,
    /// In-flight completion batches; `cancel_last_assist` pops the most recent.
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name used for requests and token counting.
    model: String,
    /// Tokens used by the conversation, once the background count has run.
    token_count: Option<usize>,
    /// Context size of `model`, as reported by `tiktoken_rs`.
    max_token_count: usize,
    /// Task computing `token_count`; replaced on every recount.
    pending_token_count: Task<Option<()>>,
    /// API key shared with the owning panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// Presumably the in-flight save task (not used in this view) — confirm.
    pending_save: Task<Result<()>>,
    /// File the conversation was loaded from, if any.
    path: Option<PathBuf>,
    /// Keeps the buffer-event subscription alive.
    _subscriptions: Vec<Subscription>,
}
630
// Conversations notify observers (e.g. their editor) through `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
634
635impl Conversation {
    /// Creates an empty conversation.
    ///
    /// The backing buffer starts empty, and its language is set to Markdown once the
    /// registry resolves it. The conversation starts with a single `User` message
    /// anchored at the start of the buffer, uses the hard-coded model below, and
    /// immediately schedules a token count.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let model = "gpt-3.5-turbo-0613";
        let markdown = language_registry.language_for_name("Markdown");
        let buffer = cx.add_model(|cx| {
            let mut buffer = Buffer::new(0, "", cx);
            buffer.set_language_registry(language_registry);
            // Apply Markdown asynchronously; bail out (logging an error) if the buffer
            // was dropped before the language finished loading.
            cx.spawn_weak(|buffer, mut cx| async move {
                let markdown = markdown.await?;
                let buffer = buffer
                    .upgrade(&cx)
                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
                buffer.update(&mut cx, |buffer, cx| {
                    buffer.set_language(Some(markdown), cx)
                });
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
            buffer
        });

        let mut this = Self {
            message_anchors: Default::default(),
            messages_metadata: Default::default(),
            next_message_id: Default::default(),
            summary: None,
            pending_summary: Task::ready(None),
            completion_count: Default::default(),
            pending_completions: Default::default(),
            token_count: None,
            max_token_count: tiktoken_rs::model::get_context_size(model),
            pending_token_count: Task::ready(None),
            model: model.into(),
            // `buffer` is still borrowed here, so it must be moved into the struct below.
            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
            pending_save: Task::ready(Ok(())),
            path: None,
            api_key,
            buffer,
        };
        // Every conversation begins with one empty user message at the buffer start.
        let message = MessageAnchor {
            id: MessageId(post_inc(&mut this.next_message_id.0)),
            start: language::Anchor::MIN,
        };
        this.message_anchors.push(message.clone());
        this.messages_metadata.insert(
            message.id,
            MessageMetadata {
                role: Role::User,
                sent_at: Local::now(),
                status: MessageStatus::Done,
            },
        );

        this.count_remaining_tokens(cx);
        this
    }
695
    /// Loads a conversation previously saved as JSON at `path`.
    ///
    /// Recreates the buffer from the saved text and re-anchors each saved message at its
    /// saved start offset. It restores per-message metadata, the model and the summary
    /// (marked `done`), and sets `next_message_id` one past the largest saved id. The
    /// returned model remembers `path`.
    ///
    /// Fails if the file cannot be read or does not deserialize as a
    /// `SavedConversation`.
    fn load(
        path: PathBuf,
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut AppContext,
    ) -> Task<Result<ModelHandle<Self>>> {
        cx.spawn(|mut cx| async move {
            let saved_conversation = fs.load(&path).await?;
            let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?;

            let model = saved_conversation.model;
            let markdown = language_registry.language_for_name("Markdown");
            let mut message_anchors = Vec::new();
            let mut next_message_id = MessageId(0);
            let buffer = cx.add_model(|cx| {
                let mut buffer = Buffer::new(0, saved_conversation.text, cx);
                // Anchors must be created against the buffer that owns the text.
                for message in saved_conversation.messages {
                    message_anchors.push(MessageAnchor {
                        id: message.id,
                        start: buffer.anchor_before(message.start),
                    });
                    next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
                }
                buffer.set_language_registry(language_registry);
                cx.spawn_weak(|buffer, mut cx| async move {
                    let markdown = markdown.await?;
                    let buffer = buffer
                        .upgrade(&cx)
                        .ok_or_else(|| anyhow!("buffer was dropped"))?;
                    buffer.update(&mut cx, |buffer, cx| {
                        buffer.set_language(Some(markdown), cx)
                    });
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx);
                buffer
            });
            let conversation = cx.add_model(|cx| {
                let mut this = Self {
                    message_anchors,
                    messages_metadata: saved_conversation.message_metadata,
                    next_message_id,
                    // A loaded summary is final; `summarize` only runs when it is `None`.
                    summary: Some(Summary {
                        text: saved_conversation.summary,
                        done: true,
                    }),
                    pending_summary: Task::ready(None),
                    completion_count: Default::default(),
                    pending_completions: Default::default(),
                    token_count: None,
                    max_token_count: tiktoken_rs::model::get_context_size(&model),
                    pending_token_count: Task::ready(None),
                    model,
                    _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
                    pending_save: Task::ready(Ok(())),
                    path: Some(path),
                    api_key,
                    buffer,
                };

                this.count_remaining_tokens(cx);
                this
            });
            Ok(conversation)
        })
    }
763
764 fn handle_buffer_event(
765 &mut self,
766 _: ModelHandle<Buffer>,
767 event: &language::Event,
768 cx: &mut ModelContext<Self>,
769 ) {
770 match event {
771 language::Event::Edited => {
772 self.count_remaining_tokens(cx);
773 cx.emit(ConversationEvent::MessagesEdited);
774 }
775 _ => {}
776 }
777 }
778
779 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
780 let messages = self
781 .messages(cx)
782 .into_iter()
783 .filter_map(|message| {
784 Some(tiktoken_rs::ChatCompletionRequestMessage {
785 role: match message.role {
786 Role::User => "user".into(),
787 Role::Assistant => "assistant".into(),
788 Role::System => "system".into(),
789 },
790 content: self.buffer.read(cx).text_for_range(message.range).collect(),
791 name: None,
792 })
793 })
794 .collect::<Vec<_>>();
795 let model = self.model.clone();
796 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
797 async move {
798 cx.background().timer(Duration::from_millis(200)).await;
799 let token_count = cx
800 .background()
801 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
802 .await?;
803
804 this.upgrade(&cx)
805 .ok_or_else(|| anyhow!("conversation was dropped"))?
806 .update(&mut cx, |this, cx| {
807 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
808 this.token_count = Some(token_count);
809 cx.notify()
810 });
811 anyhow::Ok(())
812 }
813 .log_err()
814 });
815 }
816
817 fn remaining_tokens(&self) -> Option<isize> {
818 Some(self.max_token_count as isize - self.token_count? as isize)
819 }
820
    /// Switches the model used for completions and recounts tokens with its tokenizer.
    /// `max_token_count` is refreshed for the new model when that count finishes.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
826
    /// Requests assistant replies for the messages in `selected_messages`.
    ///
    /// Selected ids without metadata are ignored. The other selected messages are
    /// handled by role:
    ///
    /// - For an `Assistant` message, no request is made. An empty `User` message
    ///   (status `Done`) is inserted after it and collected into the return value.
    /// - For any other message, a streaming request is built from all `Done` messages.
    ///   A system note is added right after the selected message, plus a final system
    ///   instruction naming the selected message's id. If an API key is set, a
    ///   `Pending` assistant message is inserted after the selected message, and a task
    ///   streams the response text into it. When the stream ends, the task marks the
    ///   message `Done`, or `Error(..)` if it failed.
    ///
    /// All streaming tasks spawned by one call are stored together as a single
    /// `PendingCompletion`, so `cancel_last_assist` drops them as a batch.
    ///
    /// Returns the user messages inserted after selected assistant messages.
    fn assist(
        &mut self,
        selected_messages: HashSet<MessageId>,
        cx: &mut ModelContext<Self>,
    ) -> Vec<MessageAnchor> {
        let mut user_messages = Vec::new();
        let mut tasks = Vec::new();
        for selected_message_id in selected_messages {
            let selected_message_role =
                if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
                    metadata.role
                } else {
                    continue;
                };

            if selected_message_role == Role::Assistant {
                if let Some(user_message) = self.insert_message_after(
                    selected_message_id,
                    Role::User,
                    MessageStatus::Done,
                    cx,
                ) {
                    user_messages.push(user_message);
                } else {
                    continue;
                }
            } else {
                let request = OpenAIRequest {
                    model: self.model.clone(),
                    messages: self
                        .messages(cx)
                        .filter(|message| matches!(message.status, MessageStatus::Done))
                        .flat_map(|message| {
                            // Messages after the selected one are framed as background
                            // knowledge rather than part of the conversation.
                            let mut system_message = None;
                            if message.id == selected_message_id {
                                system_message = Some(RequestMessage {
                                    role: Role::System,
                                    content: concat!(
                                        "Treat the following messages as additional knowledge you have learned about, ",
                                        "but act as if they were not part of this conversation. That is, treat them ",
                                        "as if the user didn't see them and couldn't possibly inquire about them."
                                    ).into()
                                });
                            }

                            Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
                        })
                        .chain(Some(RequestMessage {
                            role: Role::System,
                            content: format!(
                                "Direct your reply to message with id {}. Do not include a [Message X] header.",
                                selected_message_id.0
                            ),
                        }))
                        .collect(),
                    stream: true,
                };

                let Some(api_key) = self.api_key.borrow().clone() else { continue };
                let stream = stream_completion(api_key, cx.background().clone(), request);
                // The selected message was just found in `messages_metadata`, so the
                // insertion is expected to succeed.
                let assistant_message = self
                    .insert_message_after(
                        selected_message_id,
                        Role::Assistant,
                        MessageStatus::Pending,
                        cx,
                    )
                    .unwrap();

                tasks.push(cx.spawn_weak({
                    |this, mut cx| async move {
                        let assistant_message_id = assistant_message.id;
                        let stream_completion = async {
                            let mut messages = stream.await?;

                            while let Some(message) = messages.next().await {
                                let mut message = message?;
                                if let Some(choice) = message.choices.pop() {
                                    this.upgrade(&cx)
                                        .ok_or_else(|| anyhow!("conversation was dropped"))?
                                        .update(&mut cx, |this, cx| {
                                            let text: Arc<str> = choice.delta.content?.into();
                                            let message_ix = this.message_anchors.iter().position(
                                                |message| message.id == assistant_message_id,
                                            )?;
                                            // Append at the end of the assistant message:
                                            // just before the next valid message's separator
                                            // newline, or at the end of the buffer.
                                            this.buffer.update(cx, |buffer, cx| {
                                                let offset = this.message_anchors[message_ix + 1..]
                                                    .iter()
                                                    .find(|message| message.start.is_valid(buffer))
                                                    .map_or(buffer.len(), |message| {
                                                        message
                                                            .start
                                                            .to_offset(buffer)
                                                            .saturating_sub(1)
                                                    });
                                                buffer.edit([(offset..offset, text)], None, cx);
                                            });
                                            cx.emit(ConversationEvent::StreamedCompletion);

                                            Some(())
                                        });
                                }
                                smol::future::yield_now().await;
                            }

                            this.upgrade(&cx)
                                .ok_or_else(|| anyhow!("conversation was dropped"))?
                                .update(&mut cx, |this, cx| {
                                    // NOTE(review): assuming `post_inc` returns the
                                    // pre-increment value, the id given to this batch is
                                    // `completion_count - 1`, so this comparison never
                                    // matches it and nothing is removed here. Removing by
                                    // the batch's own id would presumably drop (and cancel)
                                    // sibling tasks still streaming for other selected
                                    // messages, so a fix needs per-task tracking — confirm
                                    // the intended behavior.
                                    this.pending_completions.retain(|completion| {
                                        completion.id != this.completion_count
                                    });
                                    this.summarize(cx);
                                });

                            anyhow::Ok(())
                        };

                        let result = stream_completion.await;
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |this, cx| {
                                if let Some(metadata) =
                                    this.messages_metadata.get_mut(&assistant_message.id)
                                {
                                    match result {
                                        Ok(_) => {
                                            metadata.status = MessageStatus::Done;
                                        }
                                        Err(error) => {
                                            metadata.status = MessageStatus::Error(
                                                error.to_string().trim().into(),
                                            );
                                        }
                                    }
                                    cx.notify();
                                }
                            });
                        }
                    }
                }));
            }
        }

        if !tasks.is_empty() {
            self.pending_completions.push(PendingCompletion {
                id: post_inc(&mut self.completion_count),
                _tasks: tasks,
            });
        }

        user_messages
    }
978
979 fn cancel_last_assist(&mut self) -> bool {
980 self.pending_completions.pop().is_some()
981 }
982
983 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
984 for id in ids {
985 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
986 metadata.role.cycle();
987 cx.emit(ConversationEvent::MessagesEdited);
988 cx.notify();
989 }
990 }
991 }
992
993 fn insert_message_after(
994 &mut self,
995 message_id: MessageId,
996 role: Role,
997 status: MessageStatus,
998 cx: &mut ModelContext<Self>,
999 ) -> Option<MessageAnchor> {
1000 if let Some(prev_message_ix) = self
1001 .message_anchors
1002 .iter()
1003 .position(|message| message.id == message_id)
1004 {
1005 let start = self.buffer.update(cx, |buffer, cx| {
1006 let offset = self.message_anchors[prev_message_ix + 1..]
1007 .iter()
1008 .find(|message| message.start.is_valid(buffer))
1009 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1010 buffer.edit([(offset..offset, "\n")], None, cx);
1011 buffer.anchor_before(offset + 1)
1012 });
1013 let message = MessageAnchor {
1014 id: MessageId(post_inc(&mut self.next_message_id.0)),
1015 start,
1016 };
1017 self.message_anchors
1018 .insert(prev_message_ix + 1, message.clone());
1019 self.messages_metadata.insert(
1020 message.id,
1021 MessageMetadata {
1022 role,
1023 sent_at: Local::now(),
1024 status,
1025 },
1026 );
1027 cx.emit(ConversationEvent::MessagesEdited);
1028 Some(message)
1029 } else {
1030 None
1031 }
1032 }
1033
1034 fn split_message(
1035 &mut self,
1036 range: Range<usize>,
1037 cx: &mut ModelContext<Self>,
1038 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1039 let start_message = self.message_for_offset(range.start, cx);
1040 let end_message = self.message_for_offset(range.end, cx);
1041 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1042 // Prevent splitting when range spans multiple messages.
1043 if start_message.index != end_message.index {
1044 return (None, None);
1045 }
1046
1047 let message = start_message;
1048 let role = message.role;
1049 let mut edited_buffer = false;
1050
1051 let mut suffix_start = None;
1052 if range.start > message.range.start && range.end < message.range.end - 1 {
1053 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1054 suffix_start = Some(range.end + 1);
1055 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1056 suffix_start = Some(range.end);
1057 }
1058 }
1059
1060 let suffix = if let Some(suffix_start) = suffix_start {
1061 MessageAnchor {
1062 id: MessageId(post_inc(&mut self.next_message_id.0)),
1063 start: self.buffer.read(cx).anchor_before(suffix_start),
1064 }
1065 } else {
1066 self.buffer.update(cx, |buffer, cx| {
1067 buffer.edit([(range.end..range.end, "\n")], None, cx);
1068 });
1069 edited_buffer = true;
1070 MessageAnchor {
1071 id: MessageId(post_inc(&mut self.next_message_id.0)),
1072 start: self.buffer.read(cx).anchor_before(range.end + 1),
1073 }
1074 };
1075
1076 self.message_anchors
1077 .insert(message.index + 1, suffix.clone());
1078 self.messages_metadata.insert(
1079 suffix.id,
1080 MessageMetadata {
1081 role,
1082 sent_at: Local::now(),
1083 status: MessageStatus::Done,
1084 },
1085 );
1086
1087 let new_messages = if range.start == range.end || range.start == message.range.start {
1088 (None, Some(suffix))
1089 } else {
1090 let mut prefix_end = None;
1091 if range.start > message.range.start && range.end < message.range.end - 1 {
1092 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1093 prefix_end = Some(range.start + 1);
1094 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1095 == Some('\n')
1096 {
1097 prefix_end = Some(range.start);
1098 }
1099 }
1100
1101 let selection = if let Some(prefix_end) = prefix_end {
1102 cx.emit(ConversationEvent::MessagesEdited);
1103 MessageAnchor {
1104 id: MessageId(post_inc(&mut self.next_message_id.0)),
1105 start: self.buffer.read(cx).anchor_before(prefix_end),
1106 }
1107 } else {
1108 self.buffer.update(cx, |buffer, cx| {
1109 buffer.edit([(range.start..range.start, "\n")], None, cx)
1110 });
1111 edited_buffer = true;
1112 MessageAnchor {
1113 id: MessageId(post_inc(&mut self.next_message_id.0)),
1114 start: self.buffer.read(cx).anchor_before(range.end + 1),
1115 }
1116 };
1117
1118 self.message_anchors
1119 .insert(message.index + 1, selection.clone());
1120 self.messages_metadata.insert(
1121 selection.id,
1122 MessageMetadata {
1123 role,
1124 sent_at: Local::now(),
1125 status: MessageStatus::Done,
1126 },
1127 );
1128 (Some(selection), Some(suffix))
1129 };
1130
1131 if !edited_buffer {
1132 cx.emit(ConversationEvent::MessagesEdited);
1133 }
1134 new_messages
1135 } else {
1136 (None, None)
1137 }
1138 }
1139
1140 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1141 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1142 let api_key = self.api_key.borrow().clone();
1143 if let Some(api_key) = api_key {
1144 let messages = self
1145 .messages(cx)
1146 .take(2)
1147 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1148 .chain(Some(RequestMessage {
1149 role: Role::User,
1150 content:
1151 "Summarize the conversation into a short title without punctuation"
1152 .into(),
1153 }));
1154 let request = OpenAIRequest {
1155 model: self.model.clone(),
1156 messages: messages.collect(),
1157 stream: true,
1158 };
1159
1160 let stream = stream_completion(api_key, cx.background().clone(), request);
1161 self.pending_summary = cx.spawn(|this, mut cx| {
1162 async move {
1163 let mut messages = stream.await?;
1164
1165 while let Some(message) = messages.next().await {
1166 let mut message = message?;
1167 if let Some(choice) = message.choices.pop() {
1168 let text = choice.delta.content.unwrap_or_default();
1169 this.update(&mut cx, |this, cx| {
1170 this.summary
1171 .get_or_insert(Default::default())
1172 .text
1173 .push_str(&text);
1174 cx.emit(ConversationEvent::SummaryChanged);
1175 });
1176 }
1177 }
1178
1179 this.update(&mut cx, |this, cx| {
1180 if let Some(summary) = this.summary.as_mut() {
1181 summary.done = true;
1182 cx.emit(ConversationEvent::SummaryChanged);
1183 }
1184 });
1185
1186 anyhow::Ok(())
1187 }
1188 .log_err()
1189 });
1190 }
1191 }
1192 }
1193
1194 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1195 self.messages_for_offsets([offset], cx).pop()
1196 }
1197
1198 fn messages_for_offsets(
1199 &self,
1200 offsets: impl IntoIterator<Item = usize>,
1201 cx: &AppContext,
1202 ) -> Vec<Message> {
1203 let mut result = Vec::new();
1204
1205 let buffer_len = self.buffer.read(cx).len();
1206 let mut messages = self.messages(cx).peekable();
1207 let mut offsets = offsets.into_iter().peekable();
1208 while let Some(offset) = offsets.next() {
1209 // Skip messages that start after the offset.
1210 while messages.peek().map_or(false, |message| {
1211 message.range.end < offset || (message.range.end == offset && offset < buffer_len)
1212 }) {
1213 messages.next();
1214 }
1215 let Some(message) = messages.peek() else { continue };
1216
1217 // Skip offsets that are in the same message.
1218 while offsets.peek().map_or(false, |offset| {
1219 message.range.contains(offset) || message.range.end == buffer_len
1220 }) {
1221 offsets.next();
1222 }
1223
1224 result.push(message.clone());
1225 }
1226 result
1227 }
1228
1229 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1230 let buffer = self.buffer.read(cx);
1231 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1232 iter::from_fn(move || {
1233 while let Some((ix, message_anchor)) = message_anchors.next() {
1234 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1235 let message_start = message_anchor.start.to_offset(buffer);
1236 let mut message_end = None;
1237 while let Some((_, next_message)) = message_anchors.peek() {
1238 if next_message.start.is_valid(buffer) {
1239 message_end = Some(next_message.start);
1240 break;
1241 } else {
1242 message_anchors.next();
1243 }
1244 }
1245 let message_end = message_end
1246 .unwrap_or(language::Anchor::MAX)
1247 .to_offset(buffer);
1248 return Some(Message {
1249 index: ix,
1250 range: message_start..message_end,
1251 id: message_anchor.id,
1252 anchor: message_anchor.start,
1253 role: metadata.role,
1254 sent_at: metadata.sent_at,
1255 status: metadata.status.clone(),
1256 });
1257 }
1258 None
1259 })
1260 }
1261
1262 fn save(
1263 &mut self,
1264 debounce: Option<Duration>,
1265 fs: Arc<dyn Fs>,
1266 cx: &mut ModelContext<Conversation>,
1267 ) {
1268 self.pending_save = cx.spawn(|this, mut cx| async move {
1269 if let Some(debounce) = debounce {
1270 cx.background().timer(debounce).await;
1271 }
1272
1273 let (old_path, summary) = this.read_with(&cx, |this, _| {
1274 let path = this.path.clone();
1275 let summary = if let Some(summary) = this.summary.as_ref() {
1276 if summary.done {
1277 Some(summary.text.clone())
1278 } else {
1279 None
1280 }
1281 } else {
1282 None
1283 };
1284 (path, summary)
1285 });
1286
1287 if let Some(summary) = summary {
1288 let conversation = this.read_with(&cx, |this, cx| SavedConversation {
1289 zed: "conversation".into(),
1290 version: SavedConversation::VERSION.into(),
1291 text: this.buffer.read(cx).text(),
1292 message_metadata: this.messages_metadata.clone(),
1293 messages: this
1294 .message_anchors
1295 .iter()
1296 .map(|message| SavedMessage {
1297 id: message.id,
1298 start: message.start.to_offset(this.buffer.read(cx)),
1299 })
1300 .collect(),
1301 summary: summary.clone(),
1302 model: this.model.clone(),
1303 });
1304
1305 let path = if let Some(old_path) = old_path {
1306 old_path
1307 } else {
1308 let mut discriminant = 1;
1309 let mut new_path;
1310 loop {
1311 new_path = CONVERSATIONS_DIR.join(&format!(
1312 "{} - {}.zed.json",
1313 summary.trim(),
1314 discriminant
1315 ));
1316 if fs.is_file(&new_path).await {
1317 discriminant += 1;
1318 } else {
1319 break;
1320 }
1321 }
1322 new_path
1323 };
1324
1325 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1326 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1327 .await?;
1328 this.update(&mut cx, |this, _| this.path = Some(path));
1329 }
1330
1331 Ok(())
1332 });
1333 }
1334}
1335
/// Bookkeeping for an in-flight assistant completion.
struct PendingCompletion {
    /// Identifier of this completion.
    id: usize,
    // Never read (leading underscore). Presumably held only so the tasks stay
    // alive as long as this value does; TODO confirm against gpui `Task` drop semantics.
    _tasks: Vec<Task<()>>,
}
1340
/// Events emitted by a `ConversationEditor`.
enum ConversationEditorEvent {
    /// The tab's title may need re-rendering (emitted when the conversation summary changes).
    TabContentChanged,
}
1344
/// A scroll position expressed relative to the cursor.
///
/// It is captured only while the cursor row is visible, and reapplied while a
/// completion streams in so the cursor keeps its place in the viewport.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: horizontal scroll position. `y`: rows between the top of the
    /// viewport and the cursor row.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection when the position was captured.
    cursor: Anchor,
}
1350
/// Editor view for a single assistant conversation.
struct ConversationEditor {
    /// The conversation model shown and edited by this view.
    conversation: ModelHandle<Conversation>,
    /// Filesystem used when saving the conversation.
    fs: Arc<dyn Fs>,
    /// Editor displaying the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// IDs of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor-relative scroll position to restore while completions stream;
    /// `None` when the cursor is not being followed.
    scroll_position: Option<ScrollPosition>,
    // Never read; held so the observations/subscriptions registered in
    // `from_conversation` stay registered (presumably dropping one unsubscribes).
    _subscriptions: Vec<Subscription>,
}
1359
impl ConversationEditor {
    /// Creates an editor for a new `Conversation` built from `api_key` and
    /// `language_registry`.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::from_conversation(conversation, fs, cx)
    }

    /// Wraps an existing conversation in an editor view.
    ///
    /// The inner `Editor` shows the conversation's buffer, soft-wrapped at the
    /// editor width and without a gutter. The view re-renders whenever the
    /// conversation notifies. Conversation and editor events are routed to
    /// `handle_conversation_event` and `handle_editor_event`. The initial
    /// message headers are inserted before returning.
    fn from_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }

    /// Handles the `Assist` action.
    ///
    /// Passes the IDs of the messages containing each cursor to
    /// `Conversation::assist`. The cursors are then moved to the start of each
    /// message it returns, autoscrolling so they fit in view. Selections are
    /// left unchanged if it returns none.
    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);

        let user_messages = self.conversation.update(cx, |conversation, cx| {
            let selected_messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.assist(selected_messages, cx)
        });
        let new_selections = user_messages
            .iter()
            .map(|message| {
                let cursor = message
                    .start
                    .to_offset(self.conversation.read(cx).buffer.read(cx));
                cursor..cursor
            })
            .collect::<Vec<_>>();
        if !new_selections.is_empty() {
            self.editor.update(cx, |editor, cx| {
                editor.change_selections(
                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                    cx,
                    |selections| selections.select_ranges(new_selections),
                );
            });
        }
    }

    /// Handles `editor::Cancel`. If the conversation reports that it did not
    /// cancel an assist, the action propagates to other handlers.
    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
        if !self
            .conversation
            .update(cx, |conversation, _| conversation.cancel_last_assist())
        {
            cx.propagate_action();
        }
    }

    /// Cycles the role of every message that contains a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }

    /// Returns the head offset of every selection in the editor.
    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
        let selections = self.editor.read(cx).selections.all::<usize>(cx);
        selections
            .into_iter()
            .map(|selection| selection.head())
            .collect()
    }

    /// Reacts to events emitted by the conversation model.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            // Message boundaries may have moved: rebuild the headers and schedule a
            // debounced save.
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            // The summary is the tab title and the file name, so refresh the tab and
            // save immediately.
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            // Keep the cursor at the same viewport row while text streams in, if the
            // cursor is currently being followed.
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }

    /// Updates `scroll_position` from editor scroll and selection events.
    ///
    /// Autoscrolls and selection changes re-capture the cursor-relative
    /// position. Any other scroll that moves the cursor relative to the
    /// viewport stops following the cursor.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }

    /// Returns the current scroll position relative to the newest cursor, or
    /// `None` when the cursor row is outside the visible rows.
    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
        self.editor.update(cx, |editor, cx| {
            let snapshot = editor.snapshot(cx);
            let cursor = editor.selections.newest_anchor().head();
            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
            let scroll_position = editor
                .scroll_manager
                .anchor()
                .scroll_position(&snapshot.display_snapshot);

            // An unknown visible line count is treated as zero, so nothing counts as visible.
            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
                Some(ScrollPosition {
                    cursor,
                    offset_before_cursor: vec2f(
                        scroll_position.x(),
                        cursor_row - scroll_position.y(),
                    ),
                })
            } else {
                None
            }
        })
    }

    /// Replaces all message header blocks with one block per current message.
    ///
    /// Each header is a two-line sticky block above the message. It shows the
    /// sender (clicking it cycles the role), the send time, and, for failed
    /// messages, an error icon whose tooltip is the error text.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            // The conversation editor always wraps a single buffer.
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state, false);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state, false);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state, false);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }

    /// Workspace action: quotes the active editor's newest selection into the
    /// assistant panel.
    ///
    /// Markdown selections are block-quoted line by line (`> `). Other text is
    /// wrapped in a fenced code block tagged with the lowercased language name,
    /// or left untagged when the two ends of the selection are in different
    /// languages.
    ///
    /// The panel is focused even when the selection is empty. Text is inserted
    /// into the active conversation, or into a new one if none is active.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                let conversation = panel
                    .active_conversation_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }

    /// Handles `editor::Copy` for selections that span several messages.
    ///
    /// With exactly one selection covering text in more than one message, each
    /// covered part is copied under a `## <role>` heading and followed by a
    /// newline. In every other case the action propagates, so the editor's
    /// default copy runs.
    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
        let editor = self.editor.read(cx);
        let conversation = self.conversation.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let mut copied_text = String::new();
            let mut spanned_messages = 0;
            for message in conversation.messages(cx) {
                if message.range.start >= selection.range().end {
                    break;
                } else if message.range.end >= selection.range().start {
                    // Intersection of this message with the selection.
                    let range = cmp::max(message.range.start, selection.range().start)
                        ..cmp::min(message.range.end, selection.range().end);
                    if !range.is_empty() {
                        spanned_messages += 1;
                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                        for chunk in conversation.buffer.read(cx).text_for_range(range) {
                            copied_text.push_str(&chunk);
                        }
                        copied_text.push('\n');
                    }
                }
            }

            if spanned_messages > 1 {
                cx.platform()
                    .write_to_clipboard(ClipboardItem::new(copied_text));
                return;
            }
        }

        cx.propagate_action();
    }

    /// Splits messages at every selection.
    ///
    /// The buffer snapshot is taken again for each selection because a split
    /// may insert newlines into the buffer.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }

    /// Handles the `Save` action by saving immediately (no debounce).
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }

    /// Toggles the model: `gpt-4-0613` becomes `gpt-3.5-turbo-0613`, and any
    /// other model becomes `gpt-4-0613`.
    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let new_model = match conversation.model.as_str() {
                "gpt-4-0613" => "gpt-3.5-turbo-0613",
                _ => "gpt-4-0613",
            };
            conversation.set_model(new_model.into(), cx);
        });
    }

    /// Returns the conversation summary text, or `"New Context"` when no
    /// summary exists yet. The text may be partial while it is still streaming.
    fn title(&self, cx: &AppContext) -> String {
        self.conversation
            .read(cx)
            .summary
            .as_ref()
            .map(|summary| summary.text.clone())
            .unwrap_or_else(|| "New Context".into())
    }
}
1787
/// `ConversationEditor` emits `ConversationEditorEvent`s to its subscribers.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
1791
impl View for ConversationEditor {
    fn ui_name() -> &'static str {
        "ConversationEditor"
    }

    /// Renders the conversation editor with a model badge overlaid in the
    /// top-right corner.
    ///
    /// Clicking the badge cycles the model. When `remaining_tokens` returns a
    /// value, it is shown next to the badge, styled with `no_remaining_tokens`
    /// once it is zero or negative.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
        enum Model {}
        let theme = &theme::current(cx).assistant;
        let conversation = self.conversation.read(cx);
        let model = conversation.model.clone();
        let remaining_tokens = conversation.remaining_tokens().map(|remaining_tokens| {
            let remaining_tokens_style = if remaining_tokens <= 0 {
                &theme.no_remaining_tokens
            } else {
                &theme.remaining_tokens
            };
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container)
        });

        Stack::new()
            .with_child(
                ChildView::new(&self.editor, cx)
                    .contained()
                    .with_style(theme.container),
            )
            .with_child(
                Flex::row()
                    .with_child(
                        MouseEventHandler::<Model, _>::new(0, cx, |state, _| {
                            let style = theme.model.style_for(state, false);
                            Label::new(model, style.text.clone())
                                .contained()
                                .with_style(style.container)
                        })
                        .with_cursor_style(CursorStyle::PointingHand)
                        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx)),
                    )
                    .with_children(remaining_tokens)
                    .contained()
                    .with_style(theme.model_info_container)
                    .aligned()
                    .top()
                    .right(),
            )
            .into_any()
    }

    /// Forwards focus to the inner editor when this view itself is focused.
    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        if cx.is_self_focused() {
            cx.focus(&self.editor);
        }
    }
}
1850
1851impl Item for ConversationEditor {
1852 fn tab_content<V: View>(
1853 &self,
1854 _: Option<usize>,
1855 style: &theme::Tab,
1856 cx: &gpui::AppContext,
1857 ) -> AnyElement<V> {
1858 let title = truncate_and_trailoff(&self.title(cx), editor::MAX_TAB_TITLE_LEN);
1859 Label::new(title, style.label.clone()).into_any()
1860 }
1861
1862 fn tab_tooltip_text(&self, cx: &AppContext) -> Option<Cow<str>> {
1863 Some(self.title(cx).into())
1864 }
1865
1866 fn as_searchable(
1867 &self,
1868 _: &ViewHandle<Self>,
1869 ) -> Option<Box<dyn workspace::searchable::SearchableItemHandle>> {
1870 Some(Box::new(self.editor.clone()))
1871 }
1872}
1873
/// A message's identity plus the buffer anchor marking where it begins.
///
/// The message's end is not stored; it is derived from the next valid anchor
/// when messages are iterated.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    start: language::Anchor,
}
1879
/// A resolved snapshot of one message: its current byte range in the
/// conversation buffer together with its metadata.
#[derive(Clone, Debug)]
pub struct Message {
    /// Byte range of the message text in the buffer.
    range: Range<usize>,
    /// Position of the message's anchor in the conversation's anchor list.
    index: usize,
    id: MessageId,
    /// Buffer anchor at the start of the message.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
1890
1891impl Message {
1892 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
1893 let mut content = format!("[Message {}]\n", self.id.0).to_string();
1894 content.extend(buffer.text_for_range(self.range.clone()));
1895 RequestMessage {
1896 role: self.role,
1897 content: content.trim_end().into(),
1898 }
1899 }
1900}
1901
1902async fn stream_completion(
1903 api_key: String,
1904 executor: Arc<Background>,
1905 mut request: OpenAIRequest,
1906) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
1907 request.stream = true;
1908
1909 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
1910
1911 let json_data = serde_json::to_string(&request)?;
1912 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
1913 .header("Content-Type", "application/json")
1914 .header("Authorization", format!("Bearer {}", api_key))
1915 .body(json_data)?
1916 .send_async()
1917 .await?;
1918
1919 let status = response.status();
1920 if status == StatusCode::OK {
1921 executor
1922 .spawn(async move {
1923 let mut lines = BufReader::new(response.body_mut()).lines();
1924
1925 fn parse_line(
1926 line: Result<String, io::Error>,
1927 ) -> Result<Option<OpenAIResponseStreamEvent>> {
1928 if let Some(data) = line?.strip_prefix("data: ") {
1929 let event = serde_json::from_str(&data)?;
1930 Ok(Some(event))
1931 } else {
1932 Ok(None)
1933 }
1934 }
1935
1936 while let Some(line) = lines.next().await {
1937 if let Some(event) = parse_line(line).transpose() {
1938 let done = event.as_ref().map_or(false, |event| {
1939 event
1940 .choices
1941 .last()
1942 .map_or(false, |choice| choice.finish_reason.is_some())
1943 });
1944 if tx.unbounded_send(event).is_err() {
1945 break;
1946 }
1947
1948 if done {
1949 break;
1950 }
1951 }
1952 }
1953
1954 anyhow::Ok(())
1955 })
1956 .detach();
1957
1958 Ok(rx)
1959 } else {
1960 let mut body = String::new();
1961 response.body_mut().read_to_string(&mut body).await?;
1962
1963 #[derive(Deserialize)]
1964 struct OpenAIResponse {
1965 error: OpenAIError,
1966 }
1967
1968 #[derive(Deserialize)]
1969 struct OpenAIError {
1970 message: String,
1971 }
1972
1973 match serde_json::from_str::<OpenAIResponse>(&body) {
1974 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
1975 "Failed to connect to OpenAI API: {}",
1976 response.error.message,
1977 )),
1978
1979 _ => Err(anyhow!(
1980 "Failed to connect to OpenAI API: {} {}",
1981 response.status(),
1982 body,
1983 )),
1984 }
1985 }
1986}
1987
1988#[cfg(test)]
1989mod tests {
1990 use crate::MessageId;
1991
1992 use super::*;
1993 use fs::FakeFs;
1994 use gpui::{AppContext, TestAppContext};
1995 use project::Project;
1996
1997 fn init_test(cx: &mut TestAppContext) {
1998 cx.foreground().forbid_parking();
1999 cx.update(|cx| {
2000 cx.set_global(SettingsStore::test(cx));
2001 theme::init((), cx);
2002 language::init(cx);
2003 editor::init_settings(cx);
2004 crate::init(cx);
2005 workspace::init_settings(cx);
2006 Project::init_settings(cx);
2007 });
2008 }
2009
2010 #[gpui::test]
2011 async fn test_panel(cx: &mut TestAppContext) {
2012 init_test(cx);
2013
2014 let fs = FakeFs::new(cx.background());
2015 let project = Project::test(fs, [], cx).await;
2016 let (window_id, workspace) = cx.add_window(|cx| Workspace::test_new(project.clone(), cx));
2017 let weak_workspace = workspace.downgrade();
2018
2019 let panel = cx
2020 .spawn(|cx| async move { AssistantPanel::load(weak_workspace, cx).await })
2021 .await
2022 .unwrap();
2023
2024 workspace.update(cx, |workspace, cx| {
2025 workspace.add_panel(panel.clone(), cx);
2026 workspace.toggle_dock(DockPosition::Right, cx);
2027 assert!(workspace.right_dock().read(cx).is_open());
2028 cx.focus(&panel);
2029 });
2030
2031 cx.dispatch_action(window_id, workspace::ToggleZoom);
2032
2033 workspace.read_with(cx, |workspace, cx| {
2034 assert_eq!(workspace.zoomed_view(cx).unwrap(), panel);
2035 })
2036 }
2037
2038 #[gpui::test]
2039 fn test_inserting_and_removing_messages(cx: &mut AppContext) {
2040 let registry = Arc::new(LanguageRegistry::test());
2041 let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
2042 let buffer = conversation.read(cx).buffer.clone();
2043
2044 let message_1 = conversation.read(cx).message_anchors[0].clone();
2045 assert_eq!(
2046 messages(&conversation, cx),
2047 vec![(message_1.id, Role::User, 0..0)]
2048 );
2049
2050 let message_2 = conversation.update(cx, |conversation, cx| {
2051 conversation
2052 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
2053 .unwrap()
2054 });
2055 assert_eq!(
2056 messages(&conversation, cx),
2057 vec![
2058 (message_1.id, Role::User, 0..1),
2059 (message_2.id, Role::Assistant, 1..1)
2060 ]
2061 );
2062
2063 buffer.update(cx, |buffer, cx| {
2064 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
2065 });
2066 assert_eq!(
2067 messages(&conversation, cx),
2068 vec![
2069 (message_1.id, Role::User, 0..2),
2070 (message_2.id, Role::Assistant, 2..3)
2071 ]
2072 );
2073
2074 let message_3 = conversation.update(cx, |conversation, cx| {
2075 conversation
2076 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
2077 .unwrap()
2078 });
2079 assert_eq!(
2080 messages(&conversation, cx),
2081 vec![
2082 (message_1.id, Role::User, 0..2),
2083 (message_2.id, Role::Assistant, 2..4),
2084 (message_3.id, Role::User, 4..4)
2085 ]
2086 );
2087
2088 let message_4 = conversation.update(cx, |conversation, cx| {
2089 conversation
2090 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
2091 .unwrap()
2092 });
2093 assert_eq!(
2094 messages(&conversation, cx),
2095 vec![
2096 (message_1.id, Role::User, 0..2),
2097 (message_2.id, Role::Assistant, 2..4),
2098 (message_4.id, Role::User, 4..5),
2099 (message_3.id, Role::User, 5..5),
2100 ]
2101 );
2102
2103 buffer.update(cx, |buffer, cx| {
2104 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
2105 });
2106 assert_eq!(
2107 messages(&conversation, cx),
2108 vec![
2109 (message_1.id, Role::User, 0..2),
2110 (message_2.id, Role::Assistant, 2..4),
2111 (message_4.id, Role::User, 4..6),
2112 (message_3.id, Role::User, 6..7),
2113 ]
2114 );
2115
2116 // Deleting across message boundaries merges the messages.
2117 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
2118 assert_eq!(
2119 messages(&conversation, cx),
2120 vec![
2121 (message_1.id, Role::User, 0..3),
2122 (message_3.id, Role::User, 3..4),
2123 ]
2124 );
2125
2126 // Undoing the deletion should also undo the merge.
2127 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2128 assert_eq!(
2129 messages(&conversation, cx),
2130 vec![
2131 (message_1.id, Role::User, 0..2),
2132 (message_2.id, Role::Assistant, 2..4),
2133 (message_4.id, Role::User, 4..6),
2134 (message_3.id, Role::User, 6..7),
2135 ]
2136 );
2137
2138 // Redoing the deletion should also redo the merge.
2139 buffer.update(cx, |buffer, cx| buffer.redo(cx));
2140 assert_eq!(
2141 messages(&conversation, cx),
2142 vec![
2143 (message_1.id, Role::User, 0..3),
2144 (message_3.id, Role::User, 3..4),
2145 ]
2146 );
2147
2148 // Ensure we can still insert after a merged message.
2149 let message_5 = conversation.update(cx, |conversation, cx| {
2150 conversation
2151 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
2152 .unwrap()
2153 });
2154 assert_eq!(
2155 messages(&conversation, cx),
2156 vec![
2157 (message_1.id, Role::User, 0..3),
2158 (message_5.id, Role::System, 3..4),
2159 (message_3.id, Role::User, 4..5)
2160 ]
2161 );
2162 }
2163
2164 #[gpui::test]
2165 fn test_message_splitting(cx: &mut AppContext) {
2166 let registry = Arc::new(LanguageRegistry::test());
2167 let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
2168 let buffer = conversation.read(cx).buffer.clone();
2169
2170 let message_1 = conversation.read(cx).message_anchors[0].clone();
2171 assert_eq!(
2172 messages(&conversation, cx),
2173 vec![(message_1.id, Role::User, 0..0)]
2174 );
2175
2176 buffer.update(cx, |buffer, cx| {
2177 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
2178 });
2179
2180 let (_, message_2) =
2181 conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
2182 let message_2 = message_2.unwrap();
2183
2184 // We recycle newlines in the middle of a split message
2185 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
2186 assert_eq!(
2187 messages(&conversation, cx),
2188 vec![
2189 (message_1.id, Role::User, 0..4),
2190 (message_2.id, Role::User, 4..16),
2191 ]
2192 );
2193
2194 let (_, message_3) =
2195 conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
2196 let message_3 = message_3.unwrap();
2197
2198 // We don't recycle newlines at the end of a split message
2199 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
2200 assert_eq!(
2201 messages(&conversation, cx),
2202 vec![
2203 (message_1.id, Role::User, 0..4),
2204 (message_3.id, Role::User, 4..5),
2205 (message_2.id, Role::User, 5..17),
2206 ]
2207 );
2208
2209 let (_, message_4) =
2210 conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
2211 let message_4 = message_4.unwrap();
2212 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
2213 assert_eq!(
2214 messages(&conversation, cx),
2215 vec![
2216 (message_1.id, Role::User, 0..4),
2217 (message_3.id, Role::User, 4..5),
2218 (message_2.id, Role::User, 5..9),
2219 (message_4.id, Role::User, 9..17),
2220 ]
2221 );
2222
2223 let (_, message_5) =
2224 conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
2225 let message_5 = message_5.unwrap();
2226 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
2227 assert_eq!(
2228 messages(&conversation, cx),
2229 vec![
2230 (message_1.id, Role::User, 0..4),
2231 (message_3.id, Role::User, 4..5),
2232 (message_2.id, Role::User, 5..9),
2233 (message_4.id, Role::User, 9..10),
2234 (message_5.id, Role::User, 10..18),
2235 ]
2236 );
2237
2238 let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
2239 conversation.split_message(14..16, cx)
2240 });
2241 let message_6 = message_6.unwrap();
2242 let message_7 = message_7.unwrap();
2243 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
2244 assert_eq!(
2245 messages(&conversation, cx),
2246 vec![
2247 (message_1.id, Role::User, 0..4),
2248 (message_3.id, Role::User, 4..5),
2249 (message_2.id, Role::User, 5..9),
2250 (message_4.id, Role::User, 9..10),
2251 (message_5.id, Role::User, 10..14),
2252 (message_6.id, Role::User, 14..17),
2253 (message_7.id, Role::User, 17..19),
2254 ]
2255 );
2256 }
2257
2258 #[gpui::test]
2259 fn test_messages_for_offsets(cx: &mut AppContext) {
2260 let registry = Arc::new(LanguageRegistry::test());
2261 let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
2262 let buffer = conversation.read(cx).buffer.clone();
2263
2264 let message_1 = conversation.read(cx).message_anchors[0].clone();
2265 assert_eq!(
2266 messages(&conversation, cx),
2267 vec![(message_1.id, Role::User, 0..0)]
2268 );
2269
2270 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
2271 let message_2 = conversation
2272 .update(cx, |conversation, cx| {
2273 conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
2274 })
2275 .unwrap();
2276 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
2277
2278 let message_3 = conversation
2279 .update(cx, |conversation, cx| {
2280 conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
2281 })
2282 .unwrap();
2283 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
2284
2285 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
2286 assert_eq!(
2287 messages(&conversation, cx),
2288 vec![
2289 (message_1.id, Role::User, 0..4),
2290 (message_2.id, Role::User, 4..8),
2291 (message_3.id, Role::User, 8..11)
2292 ]
2293 );
2294
2295 assert_eq!(
2296 message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
2297 [message_1.id, message_2.id, message_3.id]
2298 );
2299 assert_eq!(
2300 message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
2301 [message_1.id, message_3.id]
2302 );
2303
2304 fn message_ids_for_offsets(
2305 conversation: &ModelHandle<Conversation>,
2306 offsets: &[usize],
2307 cx: &AppContext,
2308 ) -> Vec<MessageId> {
2309 conversation
2310 .read(cx)
2311 .messages_for_offsets(offsets.iter().copied(), cx)
2312 .into_iter()
2313 .map(|message| message.id)
2314 .collect()
2315 }
2316 }
2317
2318 fn messages(
2319 conversation: &ModelHandle<Conversation>,
2320 cx: &AppContext,
2321 ) -> Vec<(MessageId, Role, Range<usize>)> {
2322 conversation
2323 .read(cx)
2324 .messages(cx)
2325 .map(|message| (message.id, message.role, message.range))
2326 .collect()
2327 }
2328}