1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI API.
///
/// Also used as the key under which the user's API key is written to, read from and deleted
/// from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions in the `assistant` namespace. On the stable release channel, `init` hides this
// namespace from the command palette.
actions!(
    assistant,
    [
        // Open a new conversation and focus the assistant panel.
        NewContext,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        // Forget the stored OpenAI API key and prompt for a new one.
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
66 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
67 filter.filtered_namespaces.insert("assistant");
68 });
69 }
70
71 settings::register::<AssistantSettings>(cx);
72 cx.add_action(
73 |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
74 if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
75 this.update(cx, |this, cx| {
76 this.new_conversation(cx);
77 })
78 }
79
80 workspace.focus_panel::<AssistantPanel>(cx);
81 },
82 );
83 cx.add_action(ConversationEditor::assist);
84 cx.capture_action(ConversationEditor::cancel_last_assist);
85 cx.capture_action(ConversationEditor::save);
86 cx.add_action(ConversationEditor::quote_selection);
87 cx.capture_action(ConversationEditor::copy);
88 cx.add_action(ConversationEditor::split);
89 cx.capture_action(ConversationEditor::cycle_message_role);
90 cx.add_action(AssistantPanel::save_api_key);
91 cx.add_action(AssistantPanel::reset_api_key);
92 cx.add_action(AssistantPanel::toggle_zoom);
93 cx.add_action(AssistantPanel::deploy);
94 cx.add_action(AssistantPanel::select_next_match);
95 cx.add_action(AssistantPanel::select_prev_match);
96 cx.add_action(AssistantPanel::handle_editor_cancel);
97 cx.add_action(
98 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
99 workspace.toggle_panel_focus::<AssistantPanel>(cx);
100 },
101 );
102}
103
/// Events emitted by `AssistantPanel` and interpreted by its `Panel` trait methods.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// The dock position setting changed (emitted from a `SettingsStore` observer).
    DockPositionChanged,
    /// Ask the dock to zoom the panel in (emitted by `toggle_zoom` while not zoomed).
    ZoomIn,
    /// Ask the dock to zoom the panel out (emitted by `toggle_zoom` while zoomed).
    ZoomOut,
    /// Reported as a focus event by `Panel::is_focus_event`.
    Focus,
    /// Reported as a close request by `Panel::should_close_on_event`.
    Close,
}
112
/// The assistant dock panel. It either shows the list of saved conversations or one of the
/// open `ConversationEditor`s, or an API-key prompt while no key is available.
pub struct AssistantPanel {
    // Owning workspace, held weakly.
    workspace: WeakViewHandle<Workspace>,
    // User-chosen width when docked left/right; `None` falls back to
    // `AssistantSettings::default_width` (see `Panel::size`).
    width: Option<f32>,
    // User-chosen height when docked at the bottom; `None` falls back to
    // `AssistantSettings::default_height`.
    height: Option<f32>,
    // Index into `editors` of the conversation being shown; `None` shows the saved list.
    active_editor_index: Option<usize>,
    // Previous value of `active_editor_index`, restored by the hamburger button.
    prev_active_editor_index: Option<usize>,
    // Every conversation editor opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    // Metadata of conversations on disk, re-listed when `CONVERSATIONS_DIR` changes.
    saved_conversations: Vec<SavedConversationMetadata>,
    // Scroll state of the saved-conversation list.
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    // Updated in `focus_in` / `focus_out`.
    has_focus: bool,
    // Toolbar hosting the buffer search bar for the active conversation editor.
    toolbar: ViewHandle<Toolbar>,
    // OpenAI API key, shared with every conversation created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    // `Some` while the API-key prompt is displayed.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Whether the env var / credential-store lookup for the key has been attempted.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Settings, editor-event and conversation subscriptions kept alive by the panel.
    subscriptions: Vec<Subscription>,
    // Task that re-lists saved conversations on file-system events.
    _watch_saved_conversations: Task<Result<()>>,
}
133
134impl AssistantPanel {
135 pub fn load(
136 workspace: WeakViewHandle<Workspace>,
137 cx: AsyncAppContext,
138 ) -> Task<Result<ViewHandle<Self>>> {
139 cx.spawn(|mut cx| async move {
140 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
141 let saved_conversations = SavedConversationMetadata::list(fs.clone())
142 .await
143 .log_err()
144 .unwrap_or_default();
145
146 // TODO: deserialize state.
147 let workspace_handle = workspace.clone();
148 workspace.update(&mut cx, |workspace, cx| {
149 cx.add_view::<Self, _>(|cx| {
150 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
151 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
152 let mut events = fs
153 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
154 .await;
155 while events.next().await.is_some() {
156 let saved_conversations = SavedConversationMetadata::list(fs.clone())
157 .await
158 .log_err()
159 .unwrap_or_default();
160 this.update(&mut cx, |this, _| {
161 this.saved_conversations = saved_conversations
162 })
163 .ok();
164 }
165
166 anyhow::Ok(())
167 });
168
169 let toolbar = cx.add_view(|cx| {
170 let mut toolbar = Toolbar::new(None);
171 toolbar.set_can_navigate(false, cx);
172 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
173 toolbar
174 });
175 let mut this = Self {
176 workspace: workspace_handle,
177 active_editor_index: Default::default(),
178 prev_active_editor_index: Default::default(),
179 editors: Default::default(),
180 saved_conversations,
181 saved_conversations_list_state: Default::default(),
182 zoomed: false,
183 has_focus: false,
184 toolbar,
185 api_key: Rc::new(RefCell::new(None)),
186 api_key_editor: None,
187 has_read_credentials: false,
188 languages: workspace.app_state().languages.clone(),
189 fs: workspace.app_state().fs.clone(),
190 width: None,
191 height: None,
192 subscriptions: Default::default(),
193 _watch_saved_conversations,
194 };
195
196 let mut old_dock_position = this.position(cx);
197 this.subscriptions =
198 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
199 let new_dock_position = this.position(cx);
200 if new_dock_position != old_dock_position {
201 old_dock_position = new_dock_position;
202 cx.emit(AssistantPanelEvent::DockPositionChanged);
203 }
204 })];
205
206 this
207 })
208 })
209 })
210 }
211
212 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
213 let editor = cx.add_view(|cx| {
214 ConversationEditor::new(
215 self.api_key.clone(),
216 self.languages.clone(),
217 self.fs.clone(),
218 cx,
219 )
220 });
221 self.add_conversation(editor.clone(), cx);
222 editor
223 }
224
225 fn add_conversation(
226 &mut self,
227 editor: ViewHandle<ConversationEditor>,
228 cx: &mut ViewContext<Self>,
229 ) {
230 self.subscriptions
231 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
232
233 let conversation = editor.read(cx).conversation.clone();
234 self.subscriptions
235 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
236
237 let index = self.editors.len();
238 self.editors.push(editor);
239 self.set_active_editor_index(Some(index), cx);
240 }
241
242 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
243 self.prev_active_editor_index = self.active_editor_index;
244 self.active_editor_index = index;
245 if let Some(editor) = self.active_editor() {
246 let editor = editor.read(cx).editor.clone();
247 self.toolbar.update(cx, |toolbar, cx| {
248 toolbar.set_active_item(Some(&editor), cx);
249 });
250 if self.has_focus(cx) {
251 cx.focus(&editor);
252 }
253 } else {
254 self.toolbar.update(cx, |toolbar, cx| {
255 toolbar.set_active_item(None, cx);
256 });
257 }
258
259 cx.notify();
260 }
261
262 fn handle_conversation_editor_event(
263 &mut self,
264 _: ViewHandle<ConversationEditor>,
265 event: &ConversationEditorEvent,
266 cx: &mut ViewContext<Self>,
267 ) {
268 match event {
269 ConversationEditorEvent::TabContentChanged => cx.notify(),
270 }
271 }
272
273 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
274 if let Some(api_key) = self
275 .api_key_editor
276 .as_ref()
277 .map(|editor| editor.read(cx).text(cx))
278 {
279 if !api_key.is_empty() {
280 cx.platform()
281 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
282 .log_err();
283 *self.api_key.borrow_mut() = Some(api_key);
284 self.api_key_editor.take();
285 cx.focus_self();
286 cx.notify();
287 }
288 } else {
289 cx.propagate_action();
290 }
291 }
292
293 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
294 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
295 self.api_key.take();
296 self.api_key_editor = Some(build_api_key_editor(cx));
297 cx.focus_self();
298 cx.notify();
299 }
300
301 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
302 if self.zoomed {
303 cx.emit(AssistantPanelEvent::ZoomOut)
304 } else {
305 cx.emit(AssistantPanelEvent::ZoomIn)
306 }
307 }
308
309 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
310 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
311 if search_bar.update(cx, |search_bar, cx| search_bar.show(action.focus, true, cx)) {
312 return;
313 }
314 }
315 cx.propagate_action();
316 }
317
318 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
319 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
320 if !search_bar.read(cx).is_dismissed() {
321 search_bar.update(cx, |search_bar, cx| {
322 search_bar.dismiss(&Default::default(), cx)
323 });
324 return;
325 }
326 }
327 cx.propagate_action();
328 }
329
330 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
331 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
332 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, cx));
333 }
334 }
335
336 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
337 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
338 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, cx));
339 }
340 }
341
342 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
343 self.editors.get(self.active_editor_index?)
344 }
345
346 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
347 enum ListConversations {}
348 let theme = theme::current(cx);
349 MouseEventHandler::<ListConversations, _>::new(0, cx, |state, _| {
350 let style = theme.assistant.hamburger_button.style_for(state);
351 Svg::for_style(style.icon.clone())
352 .contained()
353 .with_style(style.container)
354 })
355 .with_cursor_style(CursorStyle::PointingHand)
356 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
357 if this.active_editor().is_some() {
358 this.set_active_editor_index(None, cx);
359 } else {
360 this.set_active_editor_index(this.prev_active_editor_index, cx);
361 }
362 })
363 }
364
365 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
366 if self.active_editor().is_some() {
367 vec![
368 Self::render_split_button(cx).into_any(),
369 Self::render_quote_button(cx).into_any(),
370 Self::render_assist_button(cx).into_any(),
371 ]
372 } else {
373 Default::default()
374 }
375 }
376
377 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
378 let theme = theme::current(cx);
379 let tooltip_style = theme::current(cx).tooltip.clone();
380 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
381 let style = theme.assistant.split_button.style_for(state);
382 Svg::for_style(style.icon.clone())
383 .contained()
384 .with_style(style.container)
385 })
386 .with_cursor_style(CursorStyle::PointingHand)
387 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
388 if let Some(active_editor) = this.active_editor() {
389 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
390 }
391 })
392 .with_tooltip::<Split>(
393 1,
394 "Split Message".into(),
395 Some(Box::new(Split)),
396 tooltip_style,
397 cx,
398 )
399 }
400
401 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
402 let theme = theme::current(cx);
403 let tooltip_style = theme::current(cx).tooltip.clone();
404 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
405 let style = theme.assistant.assist_button.style_for(state);
406 Svg::for_style(style.icon.clone())
407 .contained()
408 .with_style(style.container)
409 })
410 .with_cursor_style(CursorStyle::PointingHand)
411 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
412 if let Some(active_editor) = this.active_editor() {
413 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
414 }
415 })
416 .with_tooltip::<Assist>(
417 1,
418 "Assist".into(),
419 Some(Box::new(Assist)),
420 tooltip_style,
421 cx,
422 )
423 }
424
425 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
426 let theme = theme::current(cx);
427 let tooltip_style = theme::current(cx).tooltip.clone();
428 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
429 let style = theme.assistant.quote_button.style_for(state);
430 Svg::for_style(style.icon.clone())
431 .contained()
432 .with_style(style.container)
433 })
434 .with_cursor_style(CursorStyle::PointingHand)
435 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
436 if let Some(workspace) = this.workspace.upgrade(cx) {
437 cx.window_context().defer(move |cx| {
438 workspace.update(cx, |workspace, cx| {
439 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
440 });
441 });
442 }
443 })
444 .with_tooltip::<QuoteSelection>(
445 1,
446 "Assist".into(),
447 Some(Box::new(QuoteSelection)),
448 tooltip_style,
449 cx,
450 )
451 }
452
453 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
454 enum AddConversation {}
455 let theme = theme::current(cx);
456 MouseEventHandler::<AddConversation, _>::new(0, cx, |state, _| {
457 let style = theme.assistant.plus_button.style_for(state);
458 Svg::for_style(style.icon.clone())
459 .contained()
460 .with_style(style.container)
461 })
462 .with_cursor_style(CursorStyle::PointingHand)
463 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
464 this.new_conversation(cx);
465 })
466 }
467
468 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
469 enum ToggleZoomButton {}
470
471 let theme = theme::current(cx);
472 let style = if self.zoomed {
473 &theme.assistant.zoom_out_button
474 } else {
475 &theme.assistant.zoom_in_button
476 };
477
478 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
479 let style = style.style_for(state);
480 Svg::for_style(style.icon.clone())
481 .contained()
482 .with_style(style.container)
483 })
484 .with_cursor_style(CursorStyle::PointingHand)
485 .on_click(MouseButton::Left, |_, this, cx| {
486 this.toggle_zoom(&ToggleZoom, cx);
487 })
488 }
489
490 fn render_saved_conversation(
491 &mut self,
492 index: usize,
493 cx: &mut ViewContext<Self>,
494 ) -> impl Element<Self> {
495 let conversation = &self.saved_conversations[index];
496 let path = conversation.path.clone();
497 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
498 let style = &theme::current(cx).assistant.saved_conversation;
499 Flex::row()
500 .with_child(
501 Label::new(
502 conversation.mtime.format("%F %I:%M%p").to_string(),
503 style.saved_at.text.clone(),
504 )
505 .aligned()
506 .contained()
507 .with_style(style.saved_at.container),
508 )
509 .with_child(
510 Label::new(conversation.title.clone(), style.title.text.clone())
511 .aligned()
512 .contained()
513 .with_style(style.title.container),
514 )
515 .contained()
516 .with_style(*style.container.style_for(state))
517 })
518 .with_cursor_style(CursorStyle::PointingHand)
519 .on_click(MouseButton::Left, move |_, this, cx| {
520 this.open_conversation(path.clone(), cx)
521 .detach_and_log_err(cx)
522 })
523 }
524
525 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
526 if let Some(ix) = self.editor_index_for_path(&path, cx) {
527 self.set_active_editor_index(Some(ix), cx);
528 return Task::ready(Ok(()));
529 }
530
531 let fs = self.fs.clone();
532 let api_key = self.api_key.clone();
533 let languages = self.languages.clone();
534 cx.spawn(|this, mut cx| async move {
535 let saved_conversation = fs.load(&path).await?;
536 let saved_conversation = serde_json::from_str(&saved_conversation)?;
537 let conversation = cx.add_model(|cx| {
538 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
539 });
540 this.update(&mut cx, |this, cx| {
541 // If, by the time we've loaded the conversation, the user has already opened
542 // the same conversation, we don't want to open it again.
543 if let Some(ix) = this.editor_index_for_path(&path, cx) {
544 this.set_active_editor_index(Some(ix), cx);
545 } else {
546 let editor = cx
547 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
548 this.add_conversation(editor, cx);
549 }
550 })?;
551 Ok(())
552 })
553 }
554
555 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
556 self.editors
557 .iter()
558 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
559 }
560}
561
562fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
563 cx.add_view(|cx| {
564 let mut editor = Editor::single_line(
565 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
566 cx,
567 );
568 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
569 editor
570 })
571}
572
// The panel emits `AssistantPanelEvent`s, which its `Panel` impl interprets.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
576
577impl View for AssistantPanel {
578 fn ui_name() -> &'static str {
579 "AssistantPanel"
580 }
581
582 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
583 let theme = &theme::current(cx);
584 let style = &theme.assistant;
585 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
586 Flex::column()
587 .with_child(
588 Text::new(
589 "Paste your OpenAI API key and press Enter to use the assistant",
590 style.api_key_prompt.text.clone(),
591 )
592 .aligned(),
593 )
594 .with_child(
595 ChildView::new(api_key_editor, cx)
596 .contained()
597 .with_style(style.api_key_editor.container)
598 .aligned(),
599 )
600 .contained()
601 .with_style(style.api_key_prompt.container)
602 .aligned()
603 .into_any()
604 } else {
605 let title = self.active_editor().map(|editor| {
606 Label::new(editor.read(cx).title(cx), style.title.text.clone())
607 .contained()
608 .with_style(style.title.container)
609 .aligned()
610 .left()
611 .flex(1., false)
612 });
613
614 Flex::column()
615 .with_child(
616 Flex::row()
617 .with_child(Self::render_hamburger_button(cx).aligned())
618 .with_children(title)
619 .with_children(
620 self.render_editor_tools(cx)
621 .into_iter()
622 .map(|tool| tool.aligned().flex_float()),
623 )
624 .with_child(Self::render_plus_button(cx).aligned().flex_float())
625 .with_child(self.render_zoom_button(cx).aligned())
626 .contained()
627 .with_style(theme.workspace.tab_bar.container)
628 .expanded()
629 .constrained()
630 .with_height(theme.workspace.tab_bar.height),
631 )
632 .with_children(if self.toolbar.read(cx).hidden() {
633 None
634 } else {
635 Some(ChildView::new(&self.toolbar, cx).expanded())
636 })
637 .with_child(if let Some(editor) = self.active_editor() {
638 ChildView::new(editor, cx).flex(1., true).into_any()
639 } else {
640 UniformList::new(
641 self.saved_conversations_list_state.clone(),
642 self.saved_conversations.len(),
643 cx,
644 |this, range, items, cx| {
645 for ix in range {
646 items.push(this.render_saved_conversation(ix, cx).into_any());
647 }
648 },
649 )
650 .flex(1., true)
651 .into_any()
652 })
653 .into_any()
654 }
655 }
656
657 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
658 self.has_focus = true;
659 self.toolbar
660 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
661 if cx.is_self_focused() {
662 if let Some(editor) = self.active_editor() {
663 cx.focus(editor);
664 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
665 cx.focus(api_key_editor);
666 }
667 }
668 }
669
670 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
671 self.has_focus = false;
672 self.toolbar
673 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
674 }
675}
676
677impl Panel for AssistantPanel {
678 fn position(&self, cx: &WindowContext) -> DockPosition {
679 match settings::get::<AssistantSettings>(cx).dock {
680 AssistantDockPosition::Left => DockPosition::Left,
681 AssistantDockPosition::Bottom => DockPosition::Bottom,
682 AssistantDockPosition::Right => DockPosition::Right,
683 }
684 }
685
686 fn position_is_valid(&self, _: DockPosition) -> bool {
687 true
688 }
689
690 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
691 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
692 let dock = match position {
693 DockPosition::Left => AssistantDockPosition::Left,
694 DockPosition::Bottom => AssistantDockPosition::Bottom,
695 DockPosition::Right => AssistantDockPosition::Right,
696 };
697 settings.dock = Some(dock);
698 });
699 }
700
701 fn size(&self, cx: &WindowContext) -> f32 {
702 let settings = settings::get::<AssistantSettings>(cx);
703 match self.position(cx) {
704 DockPosition::Left | DockPosition::Right => {
705 self.width.unwrap_or_else(|| settings.default_width)
706 }
707 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
708 }
709 }
710
711 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
712 match self.position(cx) {
713 DockPosition::Left | DockPosition::Right => self.width = Some(size),
714 DockPosition::Bottom => self.height = Some(size),
715 }
716 cx.notify();
717 }
718
719 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
720 matches!(event, AssistantPanelEvent::ZoomIn)
721 }
722
723 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
724 matches!(event, AssistantPanelEvent::ZoomOut)
725 }
726
727 fn is_zoomed(&self, _: &WindowContext) -> bool {
728 self.zoomed
729 }
730
731 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
732 self.zoomed = zoomed;
733 cx.notify();
734 }
735
736 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
737 if active {
738 if self.api_key.borrow().is_none() && !self.has_read_credentials {
739 self.has_read_credentials = true;
740 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
741 Some(api_key)
742 } else if let Some((_, api_key)) = cx
743 .platform()
744 .read_credentials(OPENAI_API_URL)
745 .log_err()
746 .flatten()
747 {
748 String::from_utf8(api_key).log_err()
749 } else {
750 None
751 };
752 if let Some(api_key) = api_key {
753 *self.api_key.borrow_mut() = Some(api_key);
754 } else if self.api_key_editor.is_none() {
755 self.api_key_editor = Some(build_api_key_editor(cx));
756 cx.notify();
757 }
758 }
759
760 if self.editors.is_empty() {
761 self.new_conversation(cx);
762 }
763 }
764 }
765
766 fn icon_path(&self) -> &'static str {
767 "icons/robot_14.svg"
768 }
769
770 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
771 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
772 }
773
774 fn should_change_position_on_event(event: &Self::Event) -> bool {
775 matches!(event, AssistantPanelEvent::DockPositionChanged)
776 }
777
778 fn should_activate_on_event(_: &Self::Event) -> bool {
779 false
780 }
781
782 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
783 matches!(event, AssistantPanelEvent::Close)
784 }
785
786 fn has_focus(&self, _: &WindowContext) -> bool {
787 self.has_focus
788 }
789
790 fn is_focus_event(event: &Self::Event) -> bool {
791 matches!(event, AssistantPanelEvent::Focus)
792 }
793}
794
/// Notifications emitted by a `Conversation` model.
enum ConversationEvent {
    /// The conversation buffer was edited (emitted from `handle_buffer_event`).
    MessagesEdited,
    /// The conversation summary changed. Presumably emitted by summarization code outside
    /// this view — confirm.
    SummaryChanged,
    /// A streamed completion made progress. Presumably emitted by `assist`'s streaming
    /// task — confirm.
    StreamedCompletion,
}
800
/// A conversation's summary text. `Summary::default()` is an empty, unfinished summary.
#[derive(Default)]
struct Summary {
    /// Summary text accumulated so far.
    text: String,
    /// Whether the summary is final (`deserialize` marks loaded summaries as done).
    done: bool,
}
806
/// A chat conversation. All message text lives in one Markdown buffer; each message starts
/// at an anchor in `message_anchors` and has metadata in `messages_metadata`.
struct Conversation {
    // Buffer holding the text of every message.
    buffer: ModelHandle<Buffer>,
    // Start anchor of each message within `buffer`.
    message_anchors: Vec<MessageAnchor>,
    // Role, send time and status, keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Id handed to the next message created.
    next_message_id: MessageId,
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    // OpenAI model name, used for requests and token counting.
    model: String,
    // Token count of the messages as of the last finished count; `None` before the first.
    token_count: Option<usize>,
    // Context size of `model` according to `tiktoken_rs`.
    max_token_count: usize,
    // In-flight token count scheduled by `count_remaining_tokens`.
    pending_token_count: Task<Option<()>>,
    // API key shared with the owning panel.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    // File this conversation was loaded from (set by `deserialize`); `None` for new ones.
    path: Option<PathBuf>,
    // Buffer-event subscription kept alive for the conversation's lifetime.
    _subscriptions: Vec<Subscription>,
}
825
// Conversations emit `ConversationEvent`s to their observers.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
829
830impl Conversation {
831 fn new(
832 api_key: Rc<RefCell<Option<String>>>,
833 language_registry: Arc<LanguageRegistry>,
834 cx: &mut ModelContext<Self>,
835 ) -> Self {
836 let model = "gpt-3.5-turbo-0613";
837 let markdown = language_registry.language_for_name("Markdown");
838 let buffer = cx.add_model(|cx| {
839 let mut buffer = Buffer::new(0, "", cx);
840 buffer.set_language_registry(language_registry);
841 cx.spawn_weak(|buffer, mut cx| async move {
842 let markdown = markdown.await?;
843 let buffer = buffer
844 .upgrade(&cx)
845 .ok_or_else(|| anyhow!("buffer was dropped"))?;
846 buffer.update(&mut cx, |buffer, cx| {
847 buffer.set_language(Some(markdown), cx)
848 });
849 anyhow::Ok(())
850 })
851 .detach_and_log_err(cx);
852 buffer
853 });
854
855 let mut this = Self {
856 message_anchors: Default::default(),
857 messages_metadata: Default::default(),
858 next_message_id: Default::default(),
859 summary: None,
860 pending_summary: Task::ready(None),
861 completion_count: Default::default(),
862 pending_completions: Default::default(),
863 token_count: None,
864 max_token_count: tiktoken_rs::model::get_context_size(model),
865 pending_token_count: Task::ready(None),
866 model: model.into(),
867 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
868 pending_save: Task::ready(Ok(())),
869 path: None,
870 api_key,
871 buffer,
872 };
873 let message = MessageAnchor {
874 id: MessageId(post_inc(&mut this.next_message_id.0)),
875 start: language::Anchor::MIN,
876 };
877 this.message_anchors.push(message.clone());
878 this.messages_metadata.insert(
879 message.id,
880 MessageMetadata {
881 role: Role::User,
882 sent_at: Local::now(),
883 status: MessageStatus::Done,
884 },
885 );
886
887 this.count_remaining_tokens(cx);
888 this
889 }
890
891 fn serialize(&self, cx: &AppContext) -> SavedConversation {
892 SavedConversation {
893 zed: "conversation".into(),
894 version: SavedConversation::VERSION.into(),
895 text: self.buffer.read(cx).text(),
896 message_metadata: self.messages_metadata.clone(),
897 messages: self
898 .messages(cx)
899 .map(|message| SavedMessage {
900 id: message.id,
901 start: message.offset_range.start,
902 })
903 .collect(),
904 summary: self
905 .summary
906 .as_ref()
907 .map(|summary| summary.text.clone())
908 .unwrap_or_default(),
909 model: self.model.clone(),
910 }
911 }
912
913 fn deserialize(
914 saved_conversation: SavedConversation,
915 path: PathBuf,
916 api_key: Rc<RefCell<Option<String>>>,
917 language_registry: Arc<LanguageRegistry>,
918 cx: &mut ModelContext<Self>,
919 ) -> Self {
920 let model = saved_conversation.model;
921 let markdown = language_registry.language_for_name("Markdown");
922 let mut message_anchors = Vec::new();
923 let mut next_message_id = MessageId(0);
924 let buffer = cx.add_model(|cx| {
925 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
926 for message in saved_conversation.messages {
927 message_anchors.push(MessageAnchor {
928 id: message.id,
929 start: buffer.anchor_before(message.start),
930 });
931 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
932 }
933 buffer.set_language_registry(language_registry);
934 cx.spawn_weak(|buffer, mut cx| async move {
935 let markdown = markdown.await?;
936 let buffer = buffer
937 .upgrade(&cx)
938 .ok_or_else(|| anyhow!("buffer was dropped"))?;
939 buffer.update(&mut cx, |buffer, cx| {
940 buffer.set_language(Some(markdown), cx)
941 });
942 anyhow::Ok(())
943 })
944 .detach_and_log_err(cx);
945 buffer
946 });
947
948 let mut this = Self {
949 message_anchors,
950 messages_metadata: saved_conversation.message_metadata,
951 next_message_id,
952 summary: Some(Summary {
953 text: saved_conversation.summary,
954 done: true,
955 }),
956 pending_summary: Task::ready(None),
957 completion_count: Default::default(),
958 pending_completions: Default::default(),
959 token_count: None,
960 max_token_count: tiktoken_rs::model::get_context_size(&model),
961 pending_token_count: Task::ready(None),
962 model,
963 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
964 pending_save: Task::ready(Ok(())),
965 path: Some(path),
966 api_key,
967 buffer,
968 };
969 this.count_remaining_tokens(cx);
970 this
971 }
972
973 fn handle_buffer_event(
974 &mut self,
975 _: ModelHandle<Buffer>,
976 event: &language::Event,
977 cx: &mut ModelContext<Self>,
978 ) {
979 match event {
980 language::Event::Edited => {
981 self.count_remaining_tokens(cx);
982 cx.emit(ConversationEvent::MessagesEdited);
983 }
984 _ => {}
985 }
986 }
987
988 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
989 let messages = self
990 .messages(cx)
991 .into_iter()
992 .filter_map(|message| {
993 Some(tiktoken_rs::ChatCompletionRequestMessage {
994 role: match message.role {
995 Role::User => "user".into(),
996 Role::Assistant => "assistant".into(),
997 Role::System => "system".into(),
998 },
999 content: self
1000 .buffer
1001 .read(cx)
1002 .text_for_range(message.offset_range)
1003 .collect(),
1004 name: None,
1005 })
1006 })
1007 .collect::<Vec<_>>();
1008 let model = self.model.clone();
1009 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1010 async move {
1011 cx.background().timer(Duration::from_millis(200)).await;
1012 let token_count = cx
1013 .background()
1014 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1015 .await?;
1016
1017 this.upgrade(&cx)
1018 .ok_or_else(|| anyhow!("conversation was dropped"))?
1019 .update(&mut cx, |this, cx| {
1020 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1021 this.token_count = Some(token_count);
1022 cx.notify()
1023 });
1024 anyhow::Ok(())
1025 }
1026 .log_err()
1027 });
1028 }
1029
1030 fn remaining_tokens(&self) -> Option<isize> {
1031 Some(self.max_token_count as isize - self.token_count? as isize)
1032 }
1033
1034 fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
1035 self.model = model;
1036 self.count_remaining_tokens(cx);
1037 cx.notify();
1038 }
1039
1040 fn assist(
1041 &mut self,
1042 selected_messages: HashSet<MessageId>,
1043 cx: &mut ModelContext<Self>,
1044 ) -> Vec<MessageAnchor> {
1045 let mut user_messages = Vec::new();
1046 let mut tasks = Vec::new();
1047
1048 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1049 message
1050 .start
1051 .is_valid(self.buffer.read(cx))
1052 .then_some(message.id)
1053 });
1054
1055 for selected_message_id in selected_messages {
1056 let selected_message_role =
1057 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1058 metadata.role
1059 } else {
1060 continue;
1061 };
1062
1063 if selected_message_role == Role::Assistant {
1064 if let Some(user_message) = self.insert_message_after(
1065 selected_message_id,
1066 Role::User,
1067 MessageStatus::Done,
1068 cx,
1069 ) {
1070 user_messages.push(user_message);
1071 } else {
1072 continue;
1073 }
1074 } else {
1075 let request = OpenAIRequest {
1076 model: self.model.clone(),
1077 messages: self
1078 .messages(cx)
1079 .filter(|message| matches!(message.status, MessageStatus::Done))
1080 .flat_map(|message| {
1081 let mut system_message = None;
1082 if message.id == selected_message_id {
1083 system_message = Some(RequestMessage {
1084 role: Role::System,
1085 content: concat!(
1086 "Treat the following messages as additional knowledge you have learned about, ",
1087 "but act as if they were not part of this conversation. That is, treat them ",
1088 "as if the user didn't see them and couldn't possibly inquire about them."
1089 ).into()
1090 });
1091 }
1092
1093 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1094 })
1095 .chain(Some(RequestMessage {
1096 role: Role::System,
1097 content: format!(
1098 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1099 selected_message_id.0
1100 ),
1101 }))
1102 .collect(),
1103 stream: true,
1104 };
1105
1106 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1107 let stream = stream_completion(api_key, cx.background().clone(), request);
1108 let assistant_message = self
1109 .insert_message_after(
1110 selected_message_id,
1111 Role::Assistant,
1112 MessageStatus::Pending,
1113 cx,
1114 )
1115 .unwrap();
1116
1117 // Queue up the user's next reply
1118 if Some(selected_message_id) == last_message_id {
1119 let user_message = self
1120 .insert_message_after(
1121 assistant_message.id,
1122 Role::User,
1123 MessageStatus::Done,
1124 cx,
1125 )
1126 .unwrap();
1127 user_messages.push(user_message);
1128 }
1129
1130 tasks.push(cx.spawn_weak({
1131 |this, mut cx| async move {
1132 let assistant_message_id = assistant_message.id;
1133 let stream_completion = async {
1134 let mut messages = stream.await?;
1135
1136 while let Some(message) = messages.next().await {
1137 let mut message = message?;
1138 if let Some(choice) = message.choices.pop() {
1139 this.upgrade(&cx)
1140 .ok_or_else(|| anyhow!("conversation was dropped"))?
1141 .update(&mut cx, |this, cx| {
1142 let text: Arc<str> = choice.delta.content?.into();
1143 let message_ix = this.message_anchors.iter().position(
1144 |message| message.id == assistant_message_id,
1145 )?;
1146 this.buffer.update(cx, |buffer, cx| {
1147 let offset = this.message_anchors[message_ix + 1..]
1148 .iter()
1149 .find(|message| message.start.is_valid(buffer))
1150 .map_or(buffer.len(), |message| {
1151 message
1152 .start
1153 .to_offset(buffer)
1154 .saturating_sub(1)
1155 });
1156 buffer.edit([(offset..offset, text)], None, cx);
1157 });
1158 cx.emit(ConversationEvent::StreamedCompletion);
1159
1160 Some(())
1161 });
1162 }
1163 smol::future::yield_now().await;
1164 }
1165
1166 this.upgrade(&cx)
1167 .ok_or_else(|| anyhow!("conversation was dropped"))?
1168 .update(&mut cx, |this, cx| {
1169 this.pending_completions.retain(|completion| {
1170 completion.id != this.completion_count
1171 });
1172 this.summarize(cx);
1173 });
1174
1175 anyhow::Ok(())
1176 };
1177
1178 let result = stream_completion.await;
1179 if let Some(this) = this.upgrade(&cx) {
1180 this.update(&mut cx, |this, cx| {
1181 if let Some(metadata) =
1182 this.messages_metadata.get_mut(&assistant_message.id)
1183 {
1184 match result {
1185 Ok(_) => {
1186 metadata.status = MessageStatus::Done;
1187 }
1188 Err(error) => {
1189 metadata.status = MessageStatus::Error(
1190 error.to_string().trim().into(),
1191 );
1192 }
1193 }
1194 cx.notify();
1195 }
1196 });
1197 }
1198 }
1199 }));
1200 }
1201 }
1202
1203 if !tasks.is_empty() {
1204 self.pending_completions.push(PendingCompletion {
1205 id: post_inc(&mut self.completion_count),
1206 _tasks: tasks,
1207 });
1208 }
1209
1210 user_messages
1211 }
1212
1213 fn cancel_last_assist(&mut self) -> bool {
1214 self.pending_completions.pop().is_some()
1215 }
1216
1217 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1218 for id in ids {
1219 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1220 metadata.role.cycle();
1221 cx.emit(ConversationEvent::MessagesEdited);
1222 cx.notify();
1223 }
1224 }
1225 }
1226
1227 fn insert_message_after(
1228 &mut self,
1229 message_id: MessageId,
1230 role: Role,
1231 status: MessageStatus,
1232 cx: &mut ModelContext<Self>,
1233 ) -> Option<MessageAnchor> {
1234 if let Some(prev_message_ix) = self
1235 .message_anchors
1236 .iter()
1237 .position(|message| message.id == message_id)
1238 {
1239 // Find the next valid message after the one we were given.
1240 let mut next_message_ix = prev_message_ix + 1;
1241 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1242 if next_message.start.is_valid(self.buffer.read(cx)) {
1243 break;
1244 }
1245 next_message_ix += 1;
1246 }
1247
1248 let start = self.buffer.update(cx, |buffer, cx| {
1249 let offset = self
1250 .message_anchors
1251 .get(next_message_ix)
1252 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1253 buffer.edit([(offset..offset, "\n")], None, cx);
1254 buffer.anchor_before(offset + 1)
1255 });
1256 let message = MessageAnchor {
1257 id: MessageId(post_inc(&mut self.next_message_id.0)),
1258 start,
1259 };
1260 self.message_anchors
1261 .insert(next_message_ix, message.clone());
1262 self.messages_metadata.insert(
1263 message.id,
1264 MessageMetadata {
1265 role,
1266 sent_at: Local::now(),
1267 status,
1268 },
1269 );
1270 cx.emit(ConversationEvent::MessagesEdited);
1271 Some(message)
1272 } else {
1273 None
1274 }
1275 }
1276
1277 fn split_message(
1278 &mut self,
1279 range: Range<usize>,
1280 cx: &mut ModelContext<Self>,
1281 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1282 let start_message = self.message_for_offset(range.start, cx);
1283 let end_message = self.message_for_offset(range.end, cx);
1284 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1285 // Prevent splitting when range spans multiple messages.
1286 if start_message.id != end_message.id {
1287 return (None, None);
1288 }
1289
1290 let message = start_message;
1291 let role = message.role;
1292 let mut edited_buffer = false;
1293
1294 let mut suffix_start = None;
1295 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1296 {
1297 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1298 suffix_start = Some(range.end + 1);
1299 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1300 suffix_start = Some(range.end);
1301 }
1302 }
1303
1304 let suffix = if let Some(suffix_start) = suffix_start {
1305 MessageAnchor {
1306 id: MessageId(post_inc(&mut self.next_message_id.0)),
1307 start: self.buffer.read(cx).anchor_before(suffix_start),
1308 }
1309 } else {
1310 self.buffer.update(cx, |buffer, cx| {
1311 buffer.edit([(range.end..range.end, "\n")], None, cx);
1312 });
1313 edited_buffer = true;
1314 MessageAnchor {
1315 id: MessageId(post_inc(&mut self.next_message_id.0)),
1316 start: self.buffer.read(cx).anchor_before(range.end + 1),
1317 }
1318 };
1319
1320 self.message_anchors
1321 .insert(message.index_range.end + 1, suffix.clone());
1322 self.messages_metadata.insert(
1323 suffix.id,
1324 MessageMetadata {
1325 role,
1326 sent_at: Local::now(),
1327 status: MessageStatus::Done,
1328 },
1329 );
1330
1331 let new_messages =
1332 if range.start == range.end || range.start == message.offset_range.start {
1333 (None, Some(suffix))
1334 } else {
1335 let mut prefix_end = None;
1336 if range.start > message.offset_range.start
1337 && range.end < message.offset_range.end - 1
1338 {
1339 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1340 prefix_end = Some(range.start + 1);
1341 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1342 == Some('\n')
1343 {
1344 prefix_end = Some(range.start);
1345 }
1346 }
1347
1348 let selection = if let Some(prefix_end) = prefix_end {
1349 cx.emit(ConversationEvent::MessagesEdited);
1350 MessageAnchor {
1351 id: MessageId(post_inc(&mut self.next_message_id.0)),
1352 start: self.buffer.read(cx).anchor_before(prefix_end),
1353 }
1354 } else {
1355 self.buffer.update(cx, |buffer, cx| {
1356 buffer.edit([(range.start..range.start, "\n")], None, cx)
1357 });
1358 edited_buffer = true;
1359 MessageAnchor {
1360 id: MessageId(post_inc(&mut self.next_message_id.0)),
1361 start: self.buffer.read(cx).anchor_before(range.end + 1),
1362 }
1363 };
1364
1365 self.message_anchors
1366 .insert(message.index_range.end + 1, selection.clone());
1367 self.messages_metadata.insert(
1368 selection.id,
1369 MessageMetadata {
1370 role,
1371 sent_at: Local::now(),
1372 status: MessageStatus::Done,
1373 },
1374 );
1375 (Some(selection), Some(suffix))
1376 };
1377
1378 if !edited_buffer {
1379 cx.emit(ConversationEvent::MessagesEdited);
1380 }
1381 new_messages
1382 } else {
1383 (None, None)
1384 }
1385 }
1386
1387 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1388 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1389 let api_key = self.api_key.borrow().clone();
1390 if let Some(api_key) = api_key {
1391 let messages = self
1392 .messages(cx)
1393 .take(2)
1394 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1395 .chain(Some(RequestMessage {
1396 role: Role::User,
1397 content:
1398 "Summarize the conversation into a short title without punctuation"
1399 .into(),
1400 }));
1401 let request = OpenAIRequest {
1402 model: self.model.clone(),
1403 messages: messages.collect(),
1404 stream: true,
1405 };
1406
1407 let stream = stream_completion(api_key, cx.background().clone(), request);
1408 self.pending_summary = cx.spawn(|this, mut cx| {
1409 async move {
1410 let mut messages = stream.await?;
1411
1412 while let Some(message) = messages.next().await {
1413 let mut message = message?;
1414 if let Some(choice) = message.choices.pop() {
1415 let text = choice.delta.content.unwrap_or_default();
1416 this.update(&mut cx, |this, cx| {
1417 this.summary
1418 .get_or_insert(Default::default())
1419 .text
1420 .push_str(&text);
1421 cx.emit(ConversationEvent::SummaryChanged);
1422 });
1423 }
1424 }
1425
1426 this.update(&mut cx, |this, cx| {
1427 if let Some(summary) = this.summary.as_mut() {
1428 summary.done = true;
1429 cx.emit(ConversationEvent::SummaryChanged);
1430 }
1431 });
1432
1433 anyhow::Ok(())
1434 }
1435 .log_err()
1436 });
1437 }
1438 }
1439 }
1440
1441 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1442 self.messages_for_offsets([offset], cx).pop()
1443 }
1444
    /// Returns the messages that contain the given buffer `offsets`.
    ///
    /// Consecutive offsets that fall in the same message produce that message only once. The
    /// scan only moves forward through the messages, so `offsets` are expected to be in
    /// ascending order (selection heads are, per the callers here — TODO confirm for any new
    /// caller). An offset not contained in any message resolves to the last message, and once
    /// the last message is reached all remaining offsets collapse into it.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset. Stop at the last message even if it
            // does not contain it.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message. If this is the last message, every
            // remaining offset belongs to it.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1475
1476 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1477 let buffer = self.buffer.read(cx);
1478 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1479 iter::from_fn(move || {
1480 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1481 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1482 let message_start = message_anchor.start.to_offset(buffer);
1483 let mut message_end = None;
1484 let mut end_ix = start_ix;
1485 while let Some((_, next_message)) = message_anchors.peek() {
1486 if next_message.start.is_valid(buffer) {
1487 message_end = Some(next_message.start);
1488 break;
1489 } else {
1490 end_ix += 1;
1491 message_anchors.next();
1492 }
1493 }
1494 let message_end = message_end
1495 .unwrap_or(language::Anchor::MAX)
1496 .to_offset(buffer);
1497 return Some(Message {
1498 index_range: start_ix..end_ix,
1499 offset_range: message_start..message_end,
1500 id: message_anchor.id,
1501 anchor: message_anchor.start,
1502 role: metadata.role,
1503 sent_at: metadata.sent_at,
1504 status: metadata.status.clone(),
1505 });
1506 }
1507 None
1508 })
1509 }
1510
1511 fn save(
1512 &mut self,
1513 debounce: Option<Duration>,
1514 fs: Arc<dyn Fs>,
1515 cx: &mut ModelContext<Conversation>,
1516 ) {
1517 self.pending_save = cx.spawn(|this, mut cx| async move {
1518 if let Some(debounce) = debounce {
1519 cx.background().timer(debounce).await;
1520 }
1521
1522 let (old_path, summary) = this.read_with(&cx, |this, _| {
1523 let path = this.path.clone();
1524 let summary = if let Some(summary) = this.summary.as_ref() {
1525 if summary.done {
1526 Some(summary.text.clone())
1527 } else {
1528 None
1529 }
1530 } else {
1531 None
1532 };
1533 (path, summary)
1534 });
1535
1536 if let Some(summary) = summary {
1537 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1538 let path = if let Some(old_path) = old_path {
1539 old_path
1540 } else {
1541 let mut discriminant = 1;
1542 let mut new_path;
1543 loop {
1544 new_path = CONVERSATIONS_DIR.join(&format!(
1545 "{} - {}.zed.json",
1546 summary.trim(),
1547 discriminant
1548 ));
1549 if fs.is_file(&new_path).await {
1550 discriminant += 1;
1551 } else {
1552 break;
1553 }
1554 }
1555 new_path
1556 };
1557
1558 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1559 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1560 .await?;
1561 this.update(&mut cx, |this, _| this.path = Some(path));
1562 }
1563
1564 Ok(())
1565 });
1566 }
1567}
1568
/// The streaming tasks started by one `Conversation::assist` call, tracked as a group so they
/// can be cancelled together by `Conversation::cancel_last_assist`.
struct PendingCompletion {
    /// Value taken from `Conversation::completion_count` for this group.
    id: usize,
    /// Never read. Held so the conversation owns the tasks; presumably dropping them cancels
    /// the in-flight streams (gpui `Task` semantics — confirm).
    _tasks: Vec<Task<()>>,
}
1573
/// Events emitted by a `ConversationEditor`.
enum ConversationEditorEvent {
    /// Emitted when the conversation's summary (which serves as the editor's title) changes.
    TabContentChanged,
}
1577
/// Where the newest cursor sits relative to the top of the viewport.
///
/// Used to keep the cursor on the same on-screen row while a completion streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: the horizontal scroll position. `y`: how many display rows the cursor is below
    /// the top of the viewport.
    offset_before_cursor: Vector2F,
    /// The cursor the offset is measured from.
    cursor: Anchor,
}
1583
/// View for a single `Conversation`: an `Editor` over the conversation's buffer, with a header
/// block rendered above each message.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// File system used when saving the conversation.
    fs: Arc<dyn Fs>,
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor position to preserve while completions stream in. `None` when the viewport
    /// should not follow the cursor.
    scroll_position: Option<ScrollPosition>,
    /// Never read. Held so the conversation and editor subscriptions stay active for the
    /// view's lifetime.
    _subscriptions: Vec<Subscription>,
}
1592
impl ConversationEditor {
    /// Creates an editor for a new `Conversation` model built from `api_key` and
    /// `language_registry`.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::for_conversation(conversation, fs, cx)
    }

    /// Creates an editor view for an existing conversation.
    ///
    /// The inner `Editor` edits the conversation's buffer, soft-wrapped at the editor width and
    /// without a gutter. The view re-renders whenever the conversation notifies, reacts to
    /// conversation and editor events, and renders the message headers immediately.
    fn for_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }

    /// Handles the `Assist` action.
    ///
    /// Asks the conversation to assist on every message that contains a cursor. If the
    /// conversation inserted new user messages, the selections move to the start of those
    /// messages and are scrolled into view.
    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);

        let user_messages = self.conversation.update(cx, |conversation, cx| {
            let selected_messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.assist(selected_messages, cx)
        });
        let new_selections = user_messages
            .iter()
            .map(|message| {
                let cursor = message
                    .start
                    .to_offset(self.conversation.read(cx).buffer.read(cx));
                cursor..cursor
            })
            .collect::<Vec<_>>();
        if !new_selections.is_empty() {
            self.editor.update(cx, |editor, cx| {
                editor.change_selections(
                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                    cx,
                    |selections| selections.select_ranges(new_selections),
                );
            });
        }
    }

    /// Handles `editor::Cancel` by cancelling the most recent pending assist.
    ///
    /// If the conversation reports nothing to cancel, the action propagates instead.
    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
        if !self
            .conversation
            .update(cx, |conversation, _| conversation.cancel_last_assist())
        {
            cx.propagate_action();
        }
    }

    /// Cycles the role of every message that contains a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }

    /// Buffer offsets of the head of every selection in the editor.
    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
        let selections = self.editor.read(cx).selections.all::<usize>(cx);
        selections
            .into_iter()
            .map(|selection| selection.head())
            .collect()
    }

    /// Reacts to conversation events:
    /// - `MessagesEdited`: re-renders the message headers and saves after a 500ms debounce.
    /// - `SummaryChanged`: signals a title change and saves immediately.
    /// - `StreamedCompletion`: if a cursor position was recorded, scrolls so the cursor stays
    ///   at the same row of the viewport.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        // Put the top of the viewport the recorded number of rows above the
                        // cursor's (possibly moved) display row.
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }

    /// Tracks whether the viewport should follow the newest cursor during streaming.
    ///
    /// - Autoscrolls and selection changes record the cursor's current viewport position.
    /// - A non-autoscroll scroll that moves the cursor away from the recorded position clears
    ///   the record.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }

    /// Returns the newest cursor's position relative to the top of the viewport.
    ///
    /// Returns `None` when the cursor's display row is outside the visible rows. This includes
    /// the case where the visible line count is unknown, which is treated as zero rows.
    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
        self.editor.update(cx, |editor, cx| {
            let snapshot = editor.snapshot(cx);
            let cursor = editor.selections.newest_anchor().head();
            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
            let scroll_position = editor
                .scroll_manager
                .anchor()
                .scroll_position(&snapshot.display_snapshot);

            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
                Some(ScrollPosition {
                    cursor,
                    offset_before_cursor: vec2f(
                        scroll_position.x(),
                        cursor_row - scroll_position.y(),
                    ),
                })
            } else {
                None
            }
        })
    }

    /// Replaces the header blocks above every message with freshly rendered ones.
    ///
    /// Each two-line sticky header shows:
    /// - the sender role, which cycles the message's role when clicked;
    /// - the time the message was sent;
    /// - an error icon with the error as a tooltip, when the message's status is `Error`.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // The conversation editor always wraps a single buffer, hence a single excerpt.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }

    /// Workspace action: quotes the newest selection of the active editor into the assistant
    /// panel.
    ///
    /// Focuses the assistant panel if it does not have focus. A non-empty selection is
    /// formatted as follows:
    /// - `> `-prefixed lines when its language is Markdown;
    /// - otherwise a fenced code block tagged with the lowercased language name. The tag is
    ///   empty when the languages at the selection's start and end differ or are unknown.
    ///
    /// The result is inserted into the active conversation editor, creating a new conversation
    /// if none is active. Does nothing when the workspace has no assistant panel or the active
    /// item is not an editor.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                let conversation = panel
                    .active_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }

    /// Handles `editor::Copy` when a single selection spans more than one message.
    ///
    /// In that case, the selected part of each spanned message is copied to the clipboard,
    /// preceded by a `## <role>` heading and followed by a newline. In all other cases the
    /// action propagates, so the normal copy runs.
    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
        let editor = self.editor.read(cx);
        let conversation = self.conversation.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let mut copied_text = String::new();
            let mut spanned_messages = 0;
            for message in conversation.messages(cx) {
                if message.offset_range.start >= selection.range().end {
                    break;
                } else if message.offset_range.end >= selection.range().start {
                    // Intersection of the message with the selection.
                    let range = cmp::max(message.offset_range.start, selection.range().start)
                        ..cmp::min(message.offset_range.end, selection.range().end);
                    if !range.is_empty() {
                        spanned_messages += 1;
                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                        for chunk in conversation.buffer.read(cx).text_for_range(range) {
                            copied_text.push_str(&chunk);
                        }
                        copied_text.push('\n');
                    }
                }
            }

            if spanned_messages > 1 {
                cx.platform()
                    .write_to_clipboard(ClipboardItem::new(copied_text));
                return;
            }
        }

        cx.propagate_action();
    }

    /// Splits messages at every selection (see `Conversation::split_message`).
    ///
    /// A fresh buffer snapshot is taken for each selection because earlier splits may insert
    /// newlines.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }

    /// Saves the conversation immediately, without debouncing.
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }

    /// Toggles the conversation's model between `gpt-4-0613` and `gpt-3.5-turbo-0613`.
    ///
    /// Any other model switches to `gpt-4-0613`.
    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let new_model = match conversation.model.as_str() {
                "gpt-4-0613" => "gpt-3.5-turbo-0613",
                _ => "gpt-4-0613",
            };
            conversation.set_model(new_model.into(), cx);
        });
    }

    /// Title for this conversation.
    ///
    /// Returns the summary text, which may still be streaming, or "New Conversation" when
    /// there is no summary yet.
    fn title(&self, cx: &AppContext) -> String {
        self.conversation
            .read(cx)
            .summary
            .as_ref()
            .map(|summary| summary.text.clone())
            .unwrap_or_else(|| "New Conversation".into())
    }

    /// Clickable label showing the conversation's model; clicking it cycles the model.
    fn render_current_model(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> impl Element<Self> {
        enum Model {}

        MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
            let style = style.model.style_for(state);
            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
                .contained()
                .with_style(style.container)
        })
        .with_cursor_style(CursorStyle::PointingHand)
        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
    }

    /// Label showing how many tokens are left in the model's context.
    ///
    /// Uses the `no_remaining_tokens` style once the count is zero or negative. Returns `None`
    /// while the count is unknown.
    fn render_remaining_tokens(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> Option<impl Element<Self>> {
        let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
        let remaining_tokens_style = if remaining_tokens <= 0 {
            &style.no_remaining_tokens
        } else {
            &style.remaining_tokens
        };
        Some(
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container),
        )
    }
}
2058
/// A `ConversationEditor` emits `ConversationEditorEvent`s.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2062
2063impl View for ConversationEditor {
2064 fn ui_name() -> &'static str {
2065 "ConversationEditor"
2066 }
2067
2068 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2069 let theme = &theme::current(cx).assistant;
2070 Stack::new()
2071 .with_child(
2072 ChildView::new(&self.editor, cx)
2073 .contained()
2074 .with_style(theme.container),
2075 )
2076 .with_child(
2077 Flex::row()
2078 .with_child(self.render_current_model(theme, cx))
2079 .with_children(self.render_remaining_tokens(theme, cx))
2080 .aligned()
2081 .top()
2082 .right(),
2083 )
2084 .into_any()
2085 }
2086
2087 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2088 if cx.is_self_focused() {
2089 cx.focus(&self.editor);
2090 }
2091 }
2092}
2093
/// Start of a message in the conversation buffer, identified by `id`.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    /// Buffer anchor at which the message starts. When it is no longer valid in the buffer,
    /// `Conversation::messages` folds this anchor into the preceding message.
    start: language::Anchor,
}
2099
/// A resolved view of one message: where it lies in the buffer plus its metadata.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets of the message's text. The end is the start of the next valid message,
    /// or the end of the buffer.
    offset_range: Range<usize>,
    /// Indices into `Conversation::message_anchors` covered by this message. The end is
    /// *inclusive*: it is the index of the last invalid anchor folded into this message, so
    /// new anchors go at `index_range.end + 1`.
    index_range: Range<usize>,
    id: MessageId,
    /// Start anchor of the message in the buffer.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2110
2111impl Message {
2112 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2113 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2114 content.extend(buffer.text_for_range(self.offset_range.clone()));
2115 RequestMessage {
2116 role: self.role,
2117 content: content.trim_end().into(),
2118 }
2119 }
2120}
2121
2122async fn stream_completion(
2123 api_key: String,
2124 executor: Arc<Background>,
2125 mut request: OpenAIRequest,
2126) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2127 request.stream = true;
2128
2129 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2130
2131 let json_data = serde_json::to_string(&request)?;
2132 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2133 .header("Content-Type", "application/json")
2134 .header("Authorization", format!("Bearer {}", api_key))
2135 .body(json_data)?
2136 .send_async()
2137 .await?;
2138
2139 let status = response.status();
2140 if status == StatusCode::OK {
2141 executor
2142 .spawn(async move {
2143 let mut lines = BufReader::new(response.body_mut()).lines();
2144
2145 fn parse_line(
2146 line: Result<String, io::Error>,
2147 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2148 if let Some(data) = line?.strip_prefix("data: ") {
2149 let event = serde_json::from_str(&data)?;
2150 Ok(Some(event))
2151 } else {
2152 Ok(None)
2153 }
2154 }
2155
2156 while let Some(line) = lines.next().await {
2157 if let Some(event) = parse_line(line).transpose() {
2158 let done = event.as_ref().map_or(false, |event| {
2159 event
2160 .choices
2161 .last()
2162 .map_or(false, |choice| choice.finish_reason.is_some())
2163 });
2164 if tx.unbounded_send(event).is_err() {
2165 break;
2166 }
2167
2168 if done {
2169 break;
2170 }
2171 }
2172 }
2173
2174 anyhow::Ok(())
2175 })
2176 .detach();
2177
2178 Ok(rx)
2179 } else {
2180 let mut body = String::new();
2181 response.body_mut().read_to_string(&mut body).await?;
2182
2183 #[derive(Deserialize)]
2184 struct OpenAIResponse {
2185 error: OpenAIError,
2186 }
2187
2188 #[derive(Deserialize)]
2189 struct OpenAIError {
2190 message: String,
2191 }
2192
2193 match serde_json::from_str::<OpenAIResponse>(&body) {
2194 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2195 "Failed to connect to OpenAI API: {}",
2196 response.error.message,
2197 )),
2198
2199 _ => Err(anyhow!(
2200 "Failed to connect to OpenAI API: {} {}",
2201 response.status(),
2202 body,
2203 )),
2204 }
2205 }
2206}
2207
#[cfg(test)]
mod tests {
    // Tests for `Conversation`'s message bookkeeping: how message ranges follow
    // buffer edits, message insertion, message splitting, offset-to-message
    // lookup, and serialization round-trips.
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    /// Inserting messages, editing the buffer, and deleting across message
    /// boundaries keep every message's offset range consistent, including
    /// across undo/redo of a deletion that merges messages.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A fresh conversation starts with a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after `message_1` grows it by one character and places the
        // new, empty message directly after it.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Plain buffer edits at the start of each message extend that message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after `message_2` again places the new message between
        // `message_2` and the previously inserted `message_3`.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// `split_message` splits a message at a cursor or selection. It reuses an
    /// existing newline when splitting in the middle of a line break, and
    /// inserts one otherwise; the resulting buffer text and ranges are checked
    /// after each split.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // For an empty range, only the second element of the returned pair is
        // used here; it is the message created after the split point.
        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        // Splitting just before an existing newline reuses it (text unchanged).
        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        // Splitting at the same offset again inserts a newline, because offset 9
        // is now at the end of `message_2`.
        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // A non-empty selection produces two new messages: one holding the
        // selected text and one holding the remainder after it.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// `messages_for_offsets` maps buffer offsets to the messages containing
    /// them. Offsets at the end of the buffer resolve to the last message.
    /// Judging by the assertions, several offsets inside one message yield that
    /// message only once.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build "aaa\nbbb\nccc" as three messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in `message_1`; 11 is the end of the buffer.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // An empty trailing message is still found at its (empty) offset.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        /// Returns the ids of the messages that `messages_for_offsets` reports
        /// for `offsets`, in the order it yields them.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Serializing a conversation and deserializing it reproduces the same
    /// buffer text and message ids, roles, and ranges.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Finalize this edit as its own transaction. The undo below then removes
        // only `_message_3`'s insertion; the asserted text shows these edits survive.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // After the undo, `_message_3` is absent from both the original and the
        // deserialized conversation.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    /// Snapshot of a conversation's messages as `(id, role, offset_range)`
    /// tuples, in the order `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}