1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI REST API.
///
/// Also used as the key under which the user's API key is stored in (and read from /
/// deleted from) the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions in the `assistant` namespace. On the stable release channel this whole
// namespace is hidden from the command palette (see `init`).
actions!(
    assistant,
    [
        NewContext,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
66 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
67 filter.filtered_namespaces.insert("assistant");
68 });
69 }
70
71 settings::register::<AssistantSettings>(cx);
72 cx.add_action(
73 |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
74 if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
75 this.update(cx, |this, cx| {
76 this.new_conversation(cx);
77 })
78 }
79
80 workspace.focus_panel::<AssistantPanel>(cx);
81 },
82 );
83 cx.add_action(ConversationEditor::assist);
84 cx.capture_action(ConversationEditor::cancel_last_assist);
85 cx.capture_action(ConversationEditor::save);
86 cx.add_action(ConversationEditor::quote_selection);
87 cx.capture_action(ConversationEditor::copy);
88 cx.add_action(ConversationEditor::split);
89 cx.capture_action(ConversationEditor::cycle_message_role);
90 cx.add_action(AssistantPanel::save_api_key);
91 cx.add_action(AssistantPanel::reset_api_key);
92 cx.add_action(AssistantPanel::toggle_zoom);
93 cx.add_action(AssistantPanel::deploy);
94 cx.add_action(AssistantPanel::select_next_match);
95 cx.add_action(AssistantPanel::select_prev_match);
96 cx.add_action(AssistantPanel::handle_editor_cancel);
97 cx.add_action(
98 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
99 workspace.toggle_panel_focus::<AssistantPanel>(cx);
100 },
101 );
102}
103
/// Events emitted by [`AssistantPanel`]; the `Panel` impl maps them onto dock behavior
/// (zoom in/out, focus, close, and re-docking).
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` when the panel is not zoomed.
    ZoomIn,
    /// Emitted by `toggle_zoom` when the panel is zoomed.
    ZoomOut,
    Focus,
    Close,
    /// Emitted when a settings change moves the panel to a different dock position.
    DockPositionChanged,
}
112
/// The dockable assistant panel: a set of open conversation editors plus a list of
/// conversations saved on disk, along with the OpenAI API key prompt.
pub struct AssistantPanel {
    workspace: WeakViewHandle<Workspace>,
    // Size used when docked left/right; `None` means use the configured default.
    width: Option<f32>,
    // Size used when docked at the bottom; `None` means use the configured default.
    height: Option<f32>,
    // Index into `editors` of the conversation being shown; `None` shows the saved list.
    active_editor_index: Option<usize>,
    // Previous value of `active_editor_index`, restored by the hamburger button.
    prev_active_editor_index: Option<usize>,
    editors: Vec<ViewHandle<ConversationEditor>>,
    // Refreshed by a background task whenever the conversations directory changes.
    saved_conversations: Vec<SavedConversationMetadata>,
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    has_focus: bool,
    // Hosts the buffer search bar for the active conversation editor.
    toolbar: ViewHandle<Toolbar>,
    // Shared with every conversation so that setting/resetting the key affects all of them.
    api_key: Rc<RefCell<Option<String>>>,
    // `Some` while the panel is prompting the user to paste an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Set once `set_active` has looked up the key, so the lookup happens only once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    subscriptions: Vec<Subscription>,
    // Handle to the task that watches the conversations directory.
    _watch_saved_conversations: Task<Result<()>>,
}
133
134impl AssistantPanel {
135 pub fn load(
136 workspace: WeakViewHandle<Workspace>,
137 cx: AsyncAppContext,
138 ) -> Task<Result<ViewHandle<Self>>> {
139 cx.spawn(|mut cx| async move {
140 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
141 let saved_conversations = SavedConversationMetadata::list(fs.clone())
142 .await
143 .log_err()
144 .unwrap_or_default();
145
146 // TODO: deserialize state.
147 let workspace_handle = workspace.clone();
148 workspace.update(&mut cx, |workspace, cx| {
149 cx.add_view::<Self, _>(|cx| {
150 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
151 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
152 let mut events = fs
153 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
154 .await;
155 while events.next().await.is_some() {
156 let saved_conversations = SavedConversationMetadata::list(fs.clone())
157 .await
158 .log_err()
159 .unwrap_or_default();
160 this.update(&mut cx, |this, _| {
161 this.saved_conversations = saved_conversations
162 })
163 .ok();
164 }
165
166 anyhow::Ok(())
167 });
168
169 let toolbar = cx.add_view(|cx| {
170 let mut toolbar = Toolbar::new(None);
171 toolbar.set_can_navigate(false, cx);
172 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
173 toolbar
174 });
175 let mut this = Self {
176 workspace: workspace_handle,
177 active_editor_index: Default::default(),
178 prev_active_editor_index: Default::default(),
179 editors: Default::default(),
180 saved_conversations,
181 saved_conversations_list_state: Default::default(),
182 zoomed: false,
183 has_focus: false,
184 toolbar,
185 api_key: Rc::new(RefCell::new(None)),
186 api_key_editor: None,
187 has_read_credentials: false,
188 languages: workspace.app_state().languages.clone(),
189 fs: workspace.app_state().fs.clone(),
190 width: None,
191 height: None,
192 subscriptions: Default::default(),
193 _watch_saved_conversations,
194 };
195
196 let mut old_dock_position = this.position(cx);
197 this.subscriptions =
198 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
199 let new_dock_position = this.position(cx);
200 if new_dock_position != old_dock_position {
201 old_dock_position = new_dock_position;
202 cx.emit(AssistantPanelEvent::DockPositionChanged);
203 }
204 })];
205
206 this
207 })
208 })
209 })
210 }
211
212 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
213 let editor = cx.add_view(|cx| {
214 ConversationEditor::new(
215 self.api_key.clone(),
216 self.languages.clone(),
217 self.fs.clone(),
218 cx,
219 )
220 });
221 self.add_conversation(editor.clone(), cx);
222 editor
223 }
224
225 fn add_conversation(
226 &mut self,
227 editor: ViewHandle<ConversationEditor>,
228 cx: &mut ViewContext<Self>,
229 ) {
230 self.subscriptions
231 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
232
233 let conversation = editor.read(cx).conversation.clone();
234 self.subscriptions
235 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
236
237 let index = self.editors.len();
238 self.editors.push(editor);
239 self.set_active_editor_index(Some(index), cx);
240 }
241
242 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
243 self.prev_active_editor_index = self.active_editor_index;
244 self.active_editor_index = index;
245 if let Some(editor) = self.active_editor() {
246 let editor = editor.read(cx).editor.clone();
247 self.toolbar.update(cx, |toolbar, cx| {
248 toolbar.set_active_item(Some(&editor), cx);
249 });
250 if self.has_focus(cx) {
251 cx.focus(&editor);
252 }
253 } else {
254 self.toolbar.update(cx, |toolbar, cx| {
255 toolbar.set_active_item(None, cx);
256 });
257 }
258
259 cx.notify();
260 }
261
262 fn handle_conversation_editor_event(
263 &mut self,
264 _: ViewHandle<ConversationEditor>,
265 event: &ConversationEditorEvent,
266 cx: &mut ViewContext<Self>,
267 ) {
268 match event {
269 ConversationEditorEvent::TabContentChanged => cx.notify(),
270 }
271 }
272
273 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
274 if let Some(api_key) = self
275 .api_key_editor
276 .as_ref()
277 .map(|editor| editor.read(cx).text(cx))
278 {
279 if !api_key.is_empty() {
280 cx.platform()
281 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
282 .log_err();
283 *self.api_key.borrow_mut() = Some(api_key);
284 self.api_key_editor.take();
285 cx.focus_self();
286 cx.notify();
287 }
288 } else {
289 cx.propagate_action();
290 }
291 }
292
293 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
294 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
295 self.api_key.take();
296 self.api_key_editor = Some(build_api_key_editor(cx));
297 cx.focus_self();
298 cx.notify();
299 }
300
301 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
302 if self.zoomed {
303 cx.emit(AssistantPanelEvent::ZoomOut)
304 } else {
305 cx.emit(AssistantPanelEvent::ZoomIn)
306 }
307 }
308
309 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
310 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
311 if search_bar.update(cx, |search_bar, cx| search_bar.show(action.focus, true, cx)) {
312 return;
313 }
314 }
315 cx.propagate_action();
316 }
317
318 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
319 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
320 if !search_bar.read(cx).is_dismissed() {
321 search_bar.update(cx, |search_bar, cx| {
322 search_bar.dismiss(&Default::default(), cx)
323 });
324 return;
325 }
326 }
327 cx.propagate_action();
328 }
329
330 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
331 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
332 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, cx));
333 }
334 }
335
336 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
337 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
338 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, cx));
339 }
340 }
341
342 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
343 self.editors.get(self.active_editor_index?)
344 }
345
346 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
347 enum ListConversations {}
348 let theme = theme::current(cx);
349 MouseEventHandler::<ListConversations, _>::new(0, cx, |state, _| {
350 let style = theme.assistant.hamburger_button.style_for(state);
351 Svg::for_style(style.icon.clone())
352 .contained()
353 .with_style(style.container)
354 })
355 .with_cursor_style(CursorStyle::PointingHand)
356 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
357 if this.active_editor().is_some() {
358 this.set_active_editor_index(None, cx);
359 } else {
360 this.set_active_editor_index(this.prev_active_editor_index, cx);
361 }
362 })
363 }
364
365 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
366 if self.active_editor().is_some() {
367 vec![
368 Self::render_split_button(cx).into_any(),
369 Self::render_quote_button(cx).into_any(),
370 Self::render_assist_button(cx).into_any(),
371 ]
372 } else {
373 Default::default()
374 }
375 }
376
377 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
378 let theme = theme::current(cx);
379 let tooltip_style = theme::current(cx).tooltip.clone();
380 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
381 let style = theme.assistant.split_button.style_for(state);
382 Svg::for_style(style.icon.clone())
383 .contained()
384 .with_style(style.container)
385 })
386 .with_cursor_style(CursorStyle::PointingHand)
387 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
388 if let Some(active_editor) = this.active_editor() {
389 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
390 }
391 })
392 .with_tooltip::<Split>(
393 1,
394 "Split Message".into(),
395 Some(Box::new(Split)),
396 tooltip_style,
397 cx,
398 )
399 }
400
401 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
402 let theme = theme::current(cx);
403 let tooltip_style = theme::current(cx).tooltip.clone();
404 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
405 let style = theme.assistant.assist_button.style_for(state);
406 Svg::for_style(style.icon.clone())
407 .contained()
408 .with_style(style.container)
409 })
410 .with_cursor_style(CursorStyle::PointingHand)
411 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
412 if let Some(active_editor) = this.active_editor() {
413 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
414 }
415 })
416 .with_tooltip::<Assist>(
417 1,
418 "Assist".into(),
419 Some(Box::new(Assist)),
420 tooltip_style,
421 cx,
422 )
423 }
424
425 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
426 let theme = theme::current(cx);
427 let tooltip_style = theme::current(cx).tooltip.clone();
428 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
429 let style = theme.assistant.quote_button.style_for(state);
430 Svg::for_style(style.icon.clone())
431 .contained()
432 .with_style(style.container)
433 })
434 .with_cursor_style(CursorStyle::PointingHand)
435 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
436 if let Some(workspace) = this.workspace.upgrade(cx) {
437 cx.window_context().defer(move |cx| {
438 workspace.update(cx, |workspace, cx| {
439 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
440 });
441 });
442 }
443 })
444 .with_tooltip::<QuoteSelection>(
445 1,
446 "Assist".into(),
447 Some(Box::new(QuoteSelection)),
448 tooltip_style,
449 cx,
450 )
451 }
452
453 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
454 enum AddConversation {}
455 let theme = theme::current(cx);
456 MouseEventHandler::<AddConversation, _>::new(0, cx, |state, _| {
457 let style = theme.assistant.plus_button.style_for(state);
458 Svg::for_style(style.icon.clone())
459 .contained()
460 .with_style(style.container)
461 })
462 .with_cursor_style(CursorStyle::PointingHand)
463 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
464 this.new_conversation(cx);
465 })
466 }
467
468 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
469 enum ToggleZoomButton {}
470
471 let theme = theme::current(cx);
472 let style = if self.zoomed {
473 &theme.assistant.zoom_out_button
474 } else {
475 &theme.assistant.zoom_in_button
476 };
477
478 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
479 let style = style.style_for(state);
480 Svg::for_style(style.icon.clone())
481 .contained()
482 .with_style(style.container)
483 })
484 .with_cursor_style(CursorStyle::PointingHand)
485 .on_click(MouseButton::Left, |_, this, cx| {
486 this.toggle_zoom(&ToggleZoom, cx);
487 })
488 }
489
490 fn render_saved_conversation(
491 &mut self,
492 index: usize,
493 cx: &mut ViewContext<Self>,
494 ) -> impl Element<Self> {
495 let conversation = &self.saved_conversations[index];
496 let path = conversation.path.clone();
497 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
498 let style = &theme::current(cx).assistant.saved_conversation;
499 Flex::row()
500 .with_child(
501 Label::new(
502 conversation.mtime.format("%F %I:%M%p").to_string(),
503 style.saved_at.text.clone(),
504 )
505 .aligned()
506 .contained()
507 .with_style(style.saved_at.container),
508 )
509 .with_child(
510 Label::new(conversation.title.clone(), style.title.text.clone())
511 .aligned()
512 .contained()
513 .with_style(style.title.container),
514 )
515 .contained()
516 .with_style(*style.container.style_for(state))
517 })
518 .with_cursor_style(CursorStyle::PointingHand)
519 .on_click(MouseButton::Left, move |_, this, cx| {
520 this.open_conversation(path.clone(), cx)
521 .detach_and_log_err(cx)
522 })
523 }
524
525 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
526 if let Some(ix) = self.editor_index_for_path(&path, cx) {
527 self.set_active_editor_index(Some(ix), cx);
528 return Task::ready(Ok(()));
529 }
530
531 let fs = self.fs.clone();
532 let api_key = self.api_key.clone();
533 let languages = self.languages.clone();
534 cx.spawn(|this, mut cx| async move {
535 let saved_conversation = fs.load(&path).await?;
536 let saved_conversation = serde_json::from_str(&saved_conversation)?;
537 let conversation = cx.add_model(|cx| {
538 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
539 });
540 this.update(&mut cx, |this, cx| {
541 // If, by the time we've loaded the conversation, the user has already opened
542 // the same conversation, we don't want to open it again.
543 if let Some(ix) = this.editor_index_for_path(&path, cx) {
544 this.set_active_editor_index(Some(ix), cx);
545 } else {
546 let editor = cx
547 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
548 this.add_conversation(editor, cx);
549 }
550 })?;
551 Ok(())
552 })
553 }
554
555 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
556 self.editors
557 .iter()
558 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
559 }
560}
561
562fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
563 cx.add_view(|cx| {
564 let mut editor = Editor::single_line(
565 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
566 cx,
567 );
568 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
569 editor
570 })
571}
572
// The panel communicates with its dock through `AssistantPanelEvent`s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
576
impl View for AssistantPanel {
    fn ui_name() -> &'static str {
        "AssistantPanel"
    }

    /// Renders either the API key prompt (while `api_key_editor` is set) or the
    /// normal panel: a tab-bar row of controls, the optional toolbar, and then the
    /// active conversation editor or, when none is active, the saved-conversation list.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
        let theme = &theme::current(cx);
        let style = &theme.assistant;
        if let Some(api_key_editor) = self.api_key_editor.as_ref() {
            Flex::column()
                .with_child(
                    Text::new(
                        "Paste your OpenAI API key and press Enter to use the assistant",
                        style.api_key_prompt.text.clone(),
                    )
                    .aligned(),
                )
                .with_child(
                    ChildView::new(api_key_editor, cx)
                        .contained()
                        .with_style(style.api_key_editor.container)
                        .aligned(),
                )
                .contained()
                .with_style(style.api_key_prompt.container)
                .aligned()
                .into_any()
        } else {
            // Title of the active conversation, if one is open.
            let title = self.active_editor().map(|editor| {
                Label::new(editor.read(cx).title(cx), style.title.text.clone())
                    .contained()
                    .with_style(style.title.container)
                    .aligned()
                    .left()
                    .flex(1., false)
            });

            Flex::column()
                .with_child(
                    // Header row styled like the workspace tab bar: hamburger, title,
                    // editor tools, new-conversation and zoom buttons.
                    Flex::row()
                        .with_child(Self::render_hamburger_button(cx).aligned())
                        .with_children(title)
                        .with_children(
                            self.render_editor_tools(cx)
                                .into_iter()
                                .map(|tool| tool.aligned().flex_float()),
                        )
                        .with_child(Self::render_plus_button(cx).aligned().flex_float())
                        .with_child(self.render_zoom_button(cx).aligned())
                        .contained()
                        .with_style(theme.workspace.tab_bar.container)
                        .expanded()
                        .constrained()
                        .with_height(theme.workspace.tab_bar.height),
                )
                .with_children(if self.toolbar.read(cx).hidden() {
                    None
                } else {
                    Some(ChildView::new(&self.toolbar, cx).expanded())
                })
                .with_child(if let Some(editor) = self.active_editor() {
                    ChildView::new(editor, cx).flex(1., true).into_any()
                } else {
                    // No active conversation: list saved conversations, rendering only
                    // the rows in the range the list asks for.
                    UniformList::new(
                        self.saved_conversations_list_state.clone(),
                        self.saved_conversations.len(),
                        cx,
                        |this, range, items, cx| {
                            for ix in range {
                                items.push(this.render_saved_conversation(ix, cx).into_any());
                            }
                        },
                    )
                    .flex(1., true)
                    .into_any()
                })
                .into_any()
        }
    }

    /// Records focus, informs the toolbar, and, when the panel itself (not a child)
    /// received focus, forwards it to the active editor or else to the API key editor.
    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        self.has_focus = true;
        self.toolbar
            .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
        if cx.is_self_focused() {
            if let Some(editor) = self.active_editor() {
                cx.focus(editor);
            } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
                cx.focus(api_key_editor);
            }
        }
    }

    /// Records loss of focus and informs the toolbar.
    fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        self.has_focus = false;
        self.toolbar
            .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
    }
}
676
677impl Panel for AssistantPanel {
678 fn position(&self, cx: &WindowContext) -> DockPosition {
679 match settings::get::<AssistantSettings>(cx).dock {
680 AssistantDockPosition::Left => DockPosition::Left,
681 AssistantDockPosition::Bottom => DockPosition::Bottom,
682 AssistantDockPosition::Right => DockPosition::Right,
683 }
684 }
685
686 fn position_is_valid(&self, _: DockPosition) -> bool {
687 true
688 }
689
690 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
691 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
692 let dock = match position {
693 DockPosition::Left => AssistantDockPosition::Left,
694 DockPosition::Bottom => AssistantDockPosition::Bottom,
695 DockPosition::Right => AssistantDockPosition::Right,
696 };
697 settings.dock = Some(dock);
698 });
699 }
700
701 fn size(&self, cx: &WindowContext) -> f32 {
702 let settings = settings::get::<AssistantSettings>(cx);
703 match self.position(cx) {
704 DockPosition::Left | DockPosition::Right => {
705 self.width.unwrap_or_else(|| settings.default_width)
706 }
707 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
708 }
709 }
710
711 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
712 match self.position(cx) {
713 DockPosition::Left | DockPosition::Right => self.width = Some(size),
714 DockPosition::Bottom => self.height = Some(size),
715 }
716 cx.notify();
717 }
718
719 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
720 matches!(event, AssistantPanelEvent::ZoomIn)
721 }
722
723 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
724 matches!(event, AssistantPanelEvent::ZoomOut)
725 }
726
727 fn is_zoomed(&self, _: &WindowContext) -> bool {
728 self.zoomed
729 }
730
731 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
732 self.zoomed = zoomed;
733 cx.notify();
734 }
735
736 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
737 if active {
738 if self.api_key.borrow().is_none() && !self.has_read_credentials {
739 self.has_read_credentials = true;
740 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
741 Some(api_key)
742 } else if let Some((_, api_key)) = cx
743 .platform()
744 .read_credentials(OPENAI_API_URL)
745 .log_err()
746 .flatten()
747 {
748 String::from_utf8(api_key).log_err()
749 } else {
750 None
751 };
752 if let Some(api_key) = api_key {
753 *self.api_key.borrow_mut() = Some(api_key);
754 } else if self.api_key_editor.is_none() {
755 self.api_key_editor = Some(build_api_key_editor(cx));
756 cx.notify();
757 }
758 }
759
760 if self.editors.is_empty() {
761 self.new_conversation(cx);
762 }
763 }
764 }
765
766 fn icon_path(&self) -> &'static str {
767 "icons/robot_14.svg"
768 }
769
770 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
771 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
772 }
773
774 fn should_change_position_on_event(event: &Self::Event) -> bool {
775 matches!(event, AssistantPanelEvent::DockPositionChanged)
776 }
777
778 fn should_activate_on_event(_: &Self::Event) -> bool {
779 false
780 }
781
782 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
783 matches!(event, AssistantPanelEvent::Close)
784 }
785
786 fn has_focus(&self, _: &WindowContext) -> bool {
787 self.has_focus
788 }
789
790 fn is_focus_event(event: &Self::Event) -> bool {
791 matches!(event, AssistantPanelEvent::Focus)
792 }
793}
794
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// Emitted after the conversation's buffer is edited.
    MessagesEdited,
    SummaryChanged,
    StreamedCompletion,
}
800
/// A conversation's summary text and whether it is complete.
///
/// Deserialized conversations start with `done: true`.
#[derive(Default)]
struct Summary {
    text: String,
    done: bool,
}
806
/// Model backing one assistant conversation: a Markdown buffer split into messages,
/// per-message metadata, token accounting, and the chosen OpenAI model.
struct Conversation {
    // Holds the text of every message; messages are delimited by anchors into it.
    buffer: ModelHandle<Buffer>,
    // Start anchor of each message, in insertion order.
    message_anchors: Vec<MessageAnchor>,
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Id to assign to the next inserted message.
    next_message_id: MessageId,
    summary: Option<Summary>,
    // NOTE(review): presumably the in-flight summary request — confirm.
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    // OpenAI model name; also used for tiktoken token counting.
    model: String,
    // Last token count computed by `count_remaining_tokens`; `None` until it finishes.
    token_count: Option<usize>,
    // Context size of `model` according to tiktoken.
    max_token_count: usize,
    // Debounced background token count; replaced on each recount.
    pending_token_count: Task<Option<()>>,
    // Shared with the panel so key changes apply to every conversation.
    api_key: Rc<RefCell<Option<String>>>,
    // NOTE(review): presumably the in-flight save task — confirm.
    pending_save: Task<Result<()>>,
    // File this conversation was loaded from, used to avoid opening it twice.
    path: Option<PathBuf>,
    _subscriptions: Vec<Subscription>,
}
825
// Conversations notify observers through `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
829
830impl Conversation {
831 fn new(
832 api_key: Rc<RefCell<Option<String>>>,
833 language_registry: Arc<LanguageRegistry>,
834 cx: &mut ModelContext<Self>,
835 ) -> Self {
836 let model = "gpt-3.5-turbo-0613";
837 let markdown = language_registry.language_for_name("Markdown");
838 let buffer = cx.add_model(|cx| {
839 let mut buffer = Buffer::new(0, "", cx);
840 buffer.set_language_registry(language_registry);
841 cx.spawn_weak(|buffer, mut cx| async move {
842 let markdown = markdown.await?;
843 let buffer = buffer
844 .upgrade(&cx)
845 .ok_or_else(|| anyhow!("buffer was dropped"))?;
846 buffer.update(&mut cx, |buffer, cx| {
847 buffer.set_language(Some(markdown), cx)
848 });
849 anyhow::Ok(())
850 })
851 .detach_and_log_err(cx);
852 buffer
853 });
854
855 let mut this = Self {
856 message_anchors: Default::default(),
857 messages_metadata: Default::default(),
858 next_message_id: Default::default(),
859 summary: None,
860 pending_summary: Task::ready(None),
861 completion_count: Default::default(),
862 pending_completions: Default::default(),
863 token_count: None,
864 max_token_count: tiktoken_rs::model::get_context_size(model),
865 pending_token_count: Task::ready(None),
866 model: model.into(),
867 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
868 pending_save: Task::ready(Ok(())),
869 path: None,
870 api_key,
871 buffer,
872 };
873 let message = MessageAnchor {
874 id: MessageId(post_inc(&mut this.next_message_id.0)),
875 start: language::Anchor::MIN,
876 };
877 this.message_anchors.push(message.clone());
878 this.messages_metadata.insert(
879 message.id,
880 MessageMetadata {
881 role: Role::User,
882 sent_at: Local::now(),
883 status: MessageStatus::Done,
884 },
885 );
886
887 this.count_remaining_tokens(cx);
888 this
889 }
890
891 fn serialize(&self, cx: &AppContext) -> SavedConversation {
892 SavedConversation {
893 zed: "conversation".into(),
894 version: SavedConversation::VERSION.into(),
895 text: self.buffer.read(cx).text(),
896 message_metadata: self.messages_metadata.clone(),
897 messages: self
898 .messages(cx)
899 .map(|message| SavedMessage {
900 id: message.id,
901 start: message.offset_range.start,
902 })
903 .collect(),
904 summary: self
905 .summary
906 .as_ref()
907 .map(|summary| summary.text.clone())
908 .unwrap_or_default(),
909 model: self.model.clone(),
910 }
911 }
912
    /// Rebuilds a conversation from its saved form, remembering `path` as its origin.
    ///
    /// Message anchors are recreated at the saved offsets, and `next_message_id` is
    /// set to one past the largest saved message id. The saved summary is treated as
    /// complete. As in `new`, Markdown is applied to the buffer asynchronously, and an
    /// initial token count is scheduled.
    fn deserialize(
        saved_conversation: SavedConversation,
        path: PathBuf,
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let model = saved_conversation.model;
        let markdown = language_registry.language_for_name("Markdown");
        let mut message_anchors = Vec::new();
        let mut next_message_id = MessageId(0);
        let buffer = cx.add_model(|cx| {
            let mut buffer = Buffer::new(0, saved_conversation.text, cx);
            // Anchors must be created against the buffer, so they are built inside
            // this callback and written back to the enclosing locals.
            for message in saved_conversation.messages {
                message_anchors.push(MessageAnchor {
                    id: message.id,
                    start: buffer.anchor_before(message.start),
                });
                next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
            }
            buffer.set_language_registry(language_registry);
            cx.spawn_weak(|buffer, mut cx| async move {
                let markdown = markdown.await?;
                let buffer = buffer
                    .upgrade(&cx)
                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
                buffer.update(&mut cx, |buffer, cx| {
                    buffer.set_language(Some(markdown), cx)
                });
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
            buffer
        });

        let mut this = Self {
            message_anchors,
            messages_metadata: saved_conversation.message_metadata,
            next_message_id,
            summary: Some(Summary {
                text: saved_conversation.summary,
                done: true,
            }),
            pending_summary: Task::ready(None),
            completion_count: Default::default(),
            pending_completions: Default::default(),
            token_count: None,
            max_token_count: tiktoken_rs::model::get_context_size(&model),
            pending_token_count: Task::ready(None),
            model,
            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
            pending_save: Task::ready(Ok(())),
            path: Some(path),
            api_key,
            buffer,
        };
        this.count_remaining_tokens(cx);
        this
    }
972
973 fn handle_buffer_event(
974 &mut self,
975 _: ModelHandle<Buffer>,
976 event: &language::Event,
977 cx: &mut ModelContext<Self>,
978 ) {
979 match event {
980 language::Event::Edited => {
981 self.count_remaining_tokens(cx);
982 cx.emit(ConversationEvent::MessagesEdited);
983 }
984 _ => {}
985 }
986 }
987
988 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
989 let messages = self
990 .messages(cx)
991 .into_iter()
992 .filter_map(|message| {
993 Some(tiktoken_rs::ChatCompletionRequestMessage {
994 role: match message.role {
995 Role::User => "user".into(),
996 Role::Assistant => "assistant".into(),
997 Role::System => "system".into(),
998 },
999 content: self
1000 .buffer
1001 .read(cx)
1002 .text_for_range(message.offset_range)
1003 .collect(),
1004 name: None,
1005 })
1006 })
1007 .collect::<Vec<_>>();
1008 let model = self.model.clone();
1009 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1010 async move {
1011 cx.background().timer(Duration::from_millis(200)).await;
1012 let token_count = cx
1013 .background()
1014 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1015 .await?;
1016
1017 this.upgrade(&cx)
1018 .ok_or_else(|| anyhow!("conversation was dropped"))?
1019 .update(&mut cx, |this, cx| {
1020 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1021 this.token_count = Some(token_count);
1022 cx.notify()
1023 });
1024 anyhow::Ok(())
1025 }
1026 .log_err()
1027 });
1028 }
1029
1030 fn remaining_tokens(&self) -> Option<isize> {
1031 Some(self.max_token_count as isize - self.token_count? as isize)
1032 }
1033
    /// Switches the OpenAI model and schedules a recount of tokens (the recount also
    /// refreshes `max_token_count` for the new model), then notifies observers.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
1039
1040 fn assist(
1041 &mut self,
1042 selected_messages: HashSet<MessageId>,
1043 cx: &mut ModelContext<Self>,
1044 ) -> Vec<MessageAnchor> {
1045 let mut user_messages = Vec::new();
1046 let mut tasks = Vec::new();
1047
1048 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1049 message
1050 .start
1051 .is_valid(self.buffer.read(cx))
1052 .then_some(message.id)
1053 });
1054
1055 for selected_message_id in selected_messages {
1056 let selected_message_role =
1057 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1058 metadata.role
1059 } else {
1060 continue;
1061 };
1062
1063 if selected_message_role == Role::Assistant {
1064 if let Some(user_message) = self.insert_message_after(
1065 selected_message_id,
1066 Role::User,
1067 MessageStatus::Done,
1068 cx,
1069 ) {
1070 user_messages.push(user_message);
1071 } else {
1072 continue;
1073 }
1074 } else {
1075 let request = OpenAIRequest {
1076 model: self.model.clone(),
1077 messages: self
1078 .messages(cx)
1079 .filter(|message| matches!(message.status, MessageStatus::Done))
1080 .flat_map(|message| {
1081 let mut system_message = None;
1082 if message.id == selected_message_id {
1083 system_message = Some(RequestMessage {
1084 role: Role::System,
1085 content: concat!(
1086 "Treat the following messages as additional knowledge you have learned about, ",
1087 "but act as if they were not part of this conversation. That is, treat them ",
1088 "as if the user didn't see them and couldn't possibly inquire about them."
1089 ).into()
1090 });
1091 }
1092
1093 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1094 })
1095 .chain(Some(RequestMessage {
1096 role: Role::System,
1097 content: format!(
1098 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1099 selected_message_id.0
1100 ),
1101 }))
1102 .collect(),
1103 stream: true,
1104 };
1105
1106 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1107 let stream = stream_completion(api_key, cx.background().clone(), request);
1108 let assistant_message = self
1109 .insert_message_after(
1110 selected_message_id,
1111 Role::Assistant,
1112 MessageStatus::Pending,
1113 cx,
1114 )
1115 .unwrap();
1116
1117 // Queue up the user's next reply
1118 if Some(selected_message_id) == last_message_id {
1119 let user_message = self
1120 .insert_message_after(
1121 assistant_message.id,
1122 Role::User,
1123 MessageStatus::Done,
1124 cx,
1125 )
1126 .unwrap();
1127 user_messages.push(user_message);
1128 }
1129
1130 tasks.push(cx.spawn_weak({
1131 |this, mut cx| async move {
1132 let assistant_message_id = assistant_message.id;
1133 let stream_completion = async {
1134 let mut messages = stream.await?;
1135
1136 while let Some(message) = messages.next().await {
1137 let mut message = message?;
1138 if let Some(choice) = message.choices.pop() {
1139 this.upgrade(&cx)
1140 .ok_or_else(|| anyhow!("conversation was dropped"))?
1141 .update(&mut cx, |this, cx| {
1142 let text: Arc<str> = choice.delta.content?.into();
1143 let message_ix = this.message_anchors.iter().position(
1144 |message| message.id == assistant_message_id,
1145 )?;
1146 this.buffer.update(cx, |buffer, cx| {
1147 let offset = this.message_anchors[message_ix + 1..]
1148 .iter()
1149 .find(|message| message.start.is_valid(buffer))
1150 .map_or(buffer.len(), |message| {
1151 message
1152 .start
1153 .to_offset(buffer)
1154 .saturating_sub(1)
1155 });
1156 buffer.edit([(offset..offset, text)], None, cx);
1157 });
1158 cx.emit(ConversationEvent::StreamedCompletion);
1159
1160 Some(())
1161 });
1162 }
1163 smol::future::yield_now().await;
1164 }
1165
1166 this.upgrade(&cx)
1167 .ok_or_else(|| anyhow!("conversation was dropped"))?
1168 .update(&mut cx, |this, cx| {
1169 this.pending_completions.retain(|completion| {
1170 completion.id != this.completion_count
1171 });
1172 this.summarize(cx);
1173 });
1174
1175 anyhow::Ok(())
1176 };
1177
1178 let result = stream_completion.await;
1179 if let Some(this) = this.upgrade(&cx) {
1180 this.update(&mut cx, |this, cx| {
1181 if let Some(metadata) =
1182 this.messages_metadata.get_mut(&assistant_message.id)
1183 {
1184 match result {
1185 Ok(_) => {
1186 metadata.status = MessageStatus::Done;
1187 }
1188 Err(error) => {
1189 metadata.status = MessageStatus::Error(
1190 error.to_string().trim().into(),
1191 );
1192 }
1193 }
1194 cx.notify();
1195 }
1196 });
1197 }
1198 }
1199 }));
1200 }
1201 }
1202
1203 if !tasks.is_empty() {
1204 self.pending_completions.push(PendingCompletion {
1205 id: post_inc(&mut self.completion_count),
1206 _tasks: tasks,
1207 });
1208 }
1209
1210 user_messages
1211 }
1212
1213 fn cancel_last_assist(&mut self) -> bool {
1214 self.pending_completions.pop().is_some()
1215 }
1216
1217 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1218 for id in ids {
1219 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1220 metadata.role.cycle();
1221 cx.emit(ConversationEvent::MessagesEdited);
1222 cx.notify();
1223 }
1224 }
1225 }
1226
1227 fn insert_message_after(
1228 &mut self,
1229 message_id: MessageId,
1230 role: Role,
1231 status: MessageStatus,
1232 cx: &mut ModelContext<Self>,
1233 ) -> Option<MessageAnchor> {
1234 if let Some(prev_message_ix) = self
1235 .message_anchors
1236 .iter()
1237 .position(|message| message.id == message_id)
1238 {
1239 // Find the next valid message after the one we were given.
1240 let mut next_message_ix = prev_message_ix + 1;
1241 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1242 if next_message.start.is_valid(self.buffer.read(cx)) {
1243 break;
1244 }
1245 next_message_ix += 1;
1246 }
1247
1248 let start = self.buffer.update(cx, |buffer, cx| {
1249 let offset = self
1250 .message_anchors
1251 .get(next_message_ix)
1252 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1253 buffer.edit([(offset..offset, "\n")], None, cx);
1254 buffer.anchor_before(offset + 1)
1255 });
1256 let message = MessageAnchor {
1257 id: MessageId(post_inc(&mut self.next_message_id.0)),
1258 start,
1259 };
1260 self.message_anchors
1261 .insert(next_message_ix, message.clone());
1262 self.messages_metadata.insert(
1263 message.id,
1264 MessageMetadata {
1265 role,
1266 sent_at: Local::now(),
1267 status,
1268 },
1269 );
1270 cx.emit(ConversationEvent::MessagesEdited);
1271 Some(message)
1272 } else {
1273 None
1274 }
1275 }
1276
1277 fn split_message(
1278 &mut self,
1279 range: Range<usize>,
1280 cx: &mut ModelContext<Self>,
1281 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1282 let start_message = self.message_for_offset(range.start, cx);
1283 let end_message = self.message_for_offset(range.end, cx);
1284 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1285 // Prevent splitting when range spans multiple messages.
1286 if start_message.id != end_message.id {
1287 return (None, None);
1288 }
1289
1290 let message = start_message;
1291 let role = message.role;
1292 let mut edited_buffer = false;
1293
1294 let mut suffix_start = None;
1295 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1296 {
1297 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1298 suffix_start = Some(range.end + 1);
1299 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1300 suffix_start = Some(range.end);
1301 }
1302 }
1303
1304 let suffix = if let Some(suffix_start) = suffix_start {
1305 MessageAnchor {
1306 id: MessageId(post_inc(&mut self.next_message_id.0)),
1307 start: self.buffer.read(cx).anchor_before(suffix_start),
1308 }
1309 } else {
1310 self.buffer.update(cx, |buffer, cx| {
1311 buffer.edit([(range.end..range.end, "\n")], None, cx);
1312 });
1313 edited_buffer = true;
1314 MessageAnchor {
1315 id: MessageId(post_inc(&mut self.next_message_id.0)),
1316 start: self.buffer.read(cx).anchor_before(range.end + 1),
1317 }
1318 };
1319
1320 self.message_anchors
1321 .insert(message.index_range.end + 1, suffix.clone());
1322 self.messages_metadata.insert(
1323 suffix.id,
1324 MessageMetadata {
1325 role,
1326 sent_at: Local::now(),
1327 status: MessageStatus::Done,
1328 },
1329 );
1330
1331 let new_messages =
1332 if range.start == range.end || range.start == message.offset_range.start {
1333 (None, Some(suffix))
1334 } else {
1335 let mut prefix_end = None;
1336 if range.start > message.offset_range.start
1337 && range.end < message.offset_range.end - 1
1338 {
1339 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1340 prefix_end = Some(range.start + 1);
1341 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1342 == Some('\n')
1343 {
1344 prefix_end = Some(range.start);
1345 }
1346 }
1347
1348 let selection = if let Some(prefix_end) = prefix_end {
1349 cx.emit(ConversationEvent::MessagesEdited);
1350 MessageAnchor {
1351 id: MessageId(post_inc(&mut self.next_message_id.0)),
1352 start: self.buffer.read(cx).anchor_before(prefix_end),
1353 }
1354 } else {
1355 self.buffer.update(cx, |buffer, cx| {
1356 buffer.edit([(range.start..range.start, "\n")], None, cx)
1357 });
1358 edited_buffer = true;
1359 MessageAnchor {
1360 id: MessageId(post_inc(&mut self.next_message_id.0)),
1361 start: self.buffer.read(cx).anchor_before(range.end + 1),
1362 }
1363 };
1364
1365 self.message_anchors
1366 .insert(message.index_range.end + 1, selection.clone());
1367 self.messages_metadata.insert(
1368 selection.id,
1369 MessageMetadata {
1370 role,
1371 sent_at: Local::now(),
1372 status: MessageStatus::Done,
1373 },
1374 );
1375 (Some(selection), Some(suffix))
1376 };
1377
1378 if !edited_buffer {
1379 cx.emit(ConversationEvent::MessagesEdited);
1380 }
1381 new_messages
1382 } else {
1383 (None, None)
1384 }
1385 }
1386
1387 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1388 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1389 let api_key = self.api_key.borrow().clone();
1390 if let Some(api_key) = api_key {
1391 let messages = self
1392 .messages(cx)
1393 .take(2)
1394 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1395 .chain(Some(RequestMessage {
1396 role: Role::User,
1397 content:
1398 "Summarize the conversation into a short title without punctuation"
1399 .into(),
1400 }));
1401 let request = OpenAIRequest {
1402 model: self.model.clone(),
1403 messages: messages.collect(),
1404 stream: true,
1405 };
1406
1407 let stream = stream_completion(api_key, cx.background().clone(), request);
1408 self.pending_summary = cx.spawn(|this, mut cx| {
1409 async move {
1410 let mut messages = stream.await?;
1411
1412 while let Some(message) = messages.next().await {
1413 let mut message = message?;
1414 if let Some(choice) = message.choices.pop() {
1415 let text = choice.delta.content.unwrap_or_default();
1416 this.update(&mut cx, |this, cx| {
1417 this.summary
1418 .get_or_insert(Default::default())
1419 .text
1420 .push_str(&text);
1421 cx.emit(ConversationEvent::SummaryChanged);
1422 });
1423 }
1424 }
1425
1426 this.update(&mut cx, |this, cx| {
1427 if let Some(summary) = this.summary.as_mut() {
1428 summary.done = true;
1429 cx.emit(ConversationEvent::SummaryChanged);
1430 }
1431 });
1432
1433 anyhow::Ok(())
1434 }
1435 .log_err()
1436 });
1437 }
1438 }
1439 }
1440
1441 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1442 self.messages_for_offsets([offset], cx).pop()
1443 }
1444
    /// Resolves buffer offsets to the messages that contain them.
    ///
    /// Messages and offsets are walked together in a single forward pass. This
    /// assumes `offsets` is in ascending order (NOTE(review): callers pass cursor
    /// heads — confirm these are always sorted).
    ///
    /// With sorted input, each message appears at most once in the result, even when
    /// several offsets fall inside it. An offset beyond every message's range
    /// resolves to the last message. Once the last message is reached, all remaining
    /// offsets collapse into it.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Never advances past the last message, so out-of-range offsets map to it.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message.
            // On the last message every remaining offset is skipped as well.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1475
1476 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1477 let buffer = self.buffer.read(cx);
1478 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1479 iter::from_fn(move || {
1480 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1481 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1482 let message_start = message_anchor.start.to_offset(buffer);
1483 let mut message_end = None;
1484 let mut end_ix = start_ix;
1485 while let Some((_, next_message)) = message_anchors.peek() {
1486 if next_message.start.is_valid(buffer) {
1487 message_end = Some(next_message.start);
1488 break;
1489 } else {
1490 end_ix += 1;
1491 message_anchors.next();
1492 }
1493 }
1494 let message_end = message_end
1495 .unwrap_or(language::Anchor::MAX)
1496 .to_offset(buffer);
1497 return Some(Message {
1498 index_range: start_ix..end_ix,
1499 offset_range: message_start..message_end,
1500 id: message_anchor.id,
1501 anchor: message_anchor.start,
1502 role: metadata.role,
1503 sent_at: metadata.sent_at,
1504 status: metadata.status.clone(),
1505 });
1506 }
1507 None
1508 })
1509 }
1510
1511 fn save(
1512 &mut self,
1513 debounce: Option<Duration>,
1514 fs: Arc<dyn Fs>,
1515 cx: &mut ModelContext<Conversation>,
1516 ) {
1517 self.pending_save = cx.spawn(|this, mut cx| async move {
1518 if let Some(debounce) = debounce {
1519 cx.background().timer(debounce).await;
1520 }
1521
1522 let (old_path, summary) = this.read_with(&cx, |this, _| {
1523 let path = this.path.clone();
1524 let summary = if let Some(summary) = this.summary.as_ref() {
1525 if summary.done {
1526 Some(summary.text.clone())
1527 } else {
1528 None
1529 }
1530 } else {
1531 None
1532 };
1533 (path, summary)
1534 });
1535
1536 if let Some(summary) = summary {
1537 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1538 let path = if let Some(old_path) = old_path {
1539 old_path
1540 } else {
1541 let mut discriminant = 1;
1542 let mut new_path;
1543 loop {
1544 new_path = CONVERSATIONS_DIR.join(&format!(
1545 "{} - {}.zed.json",
1546 summary.trim(),
1547 discriminant
1548 ));
1549 if fs.is_file(&new_path).await {
1550 discriminant += 1;
1551 } else {
1552 break;
1553 }
1554 }
1555 new_path
1556 };
1557
1558 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1559 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1560 .await?;
1561 this.update(&mut cx, |this, _| this.path = Some(path));
1562 }
1563
1564 Ok(())
1565 });
1566 }
1567}
1568
1569struct PendingCompletion {
1570 id: usize,
1571 _tasks: Vec<Task<()>>,
1572}
1573
/// Events emitted by a `ConversationEditor`.
enum ConversationEditorEvent {
    /// The tab's content (its title) may have changed; emitted when the
    /// conversation's summary changes.
    TabContentChanged,
}
1577
/// A scroll position pinned relative to the cursor.
///
/// Used to keep the cursor on the same on-screen row while a completion streams
/// into the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: the horizontal scroll position. `y`: the number of display rows between
    /// the top of the viewport and the cursor row.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection at the time the position was captured.
    cursor: Anchor,
}
1583
/// A view that edits one `Conversation` through an embedded `Editor`.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    // Passed to `Conversation::save` to persist the conversation.
    fs: Arc<dyn Fs>,
    // Editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    // Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    // When `Some`, streamed completions scroll so the cursor stays at this offset.
    scroll_position: Option<ScrollPosition>,
    // Subscriptions created in `for_conversation`; owned by the view, never read.
    _subscriptions: Vec<Subscription>,
}
1592
1593impl ConversationEditor {
1594 fn new(
1595 api_key: Rc<RefCell<Option<String>>>,
1596 language_registry: Arc<LanguageRegistry>,
1597 fs: Arc<dyn Fs>,
1598 cx: &mut ViewContext<Self>,
1599 ) -> Self {
1600 let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
1601 Self::for_conversation(conversation, fs, cx)
1602 }
1603
1604 fn for_conversation(
1605 conversation: ModelHandle<Conversation>,
1606 fs: Arc<dyn Fs>,
1607 cx: &mut ViewContext<Self>,
1608 ) -> Self {
1609 let editor = cx.add_view(|cx| {
1610 let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
1611 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1612 editor.set_show_gutter(false, cx);
1613 editor
1614 });
1615
1616 let _subscriptions = vec![
1617 cx.observe(&conversation, |_, _, cx| cx.notify()),
1618 cx.subscribe(&conversation, Self::handle_conversation_event),
1619 cx.subscribe(&editor, Self::handle_editor_event),
1620 ];
1621
1622 let mut this = Self {
1623 conversation,
1624 editor,
1625 blocks: Default::default(),
1626 scroll_position: None,
1627 fs,
1628 _subscriptions,
1629 };
1630 this.update_message_headers(cx);
1631 this
1632 }
1633
1634 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1635 let cursors = self.cursors(cx);
1636
1637 let user_messages = self.conversation.update(cx, |conversation, cx| {
1638 let selected_messages = conversation
1639 .messages_for_offsets(cursors, cx)
1640 .into_iter()
1641 .map(|message| message.id)
1642 .collect();
1643 conversation.assist(selected_messages, cx)
1644 });
1645 let new_selections = user_messages
1646 .iter()
1647 .map(|message| {
1648 let cursor = message
1649 .start
1650 .to_offset(self.conversation.read(cx).buffer.read(cx));
1651 cursor..cursor
1652 })
1653 .collect::<Vec<_>>();
1654 if !new_selections.is_empty() {
1655 self.editor.update(cx, |editor, cx| {
1656 editor.change_selections(
1657 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1658 cx,
1659 |selections| selections.select_ranges(new_selections),
1660 );
1661 });
1662 }
1663 }
1664
1665 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1666 if !self
1667 .conversation
1668 .update(cx, |conversation, _| conversation.cancel_last_assist())
1669 {
1670 cx.propagate_action();
1671 }
1672 }
1673
1674 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1675 let cursors = self.cursors(cx);
1676 self.conversation.update(cx, |conversation, cx| {
1677 let messages = conversation
1678 .messages_for_offsets(cursors, cx)
1679 .into_iter()
1680 .map(|message| message.id)
1681 .collect();
1682 conversation.cycle_message_roles(messages, cx)
1683 });
1684 }
1685
1686 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1687 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1688 selections
1689 .into_iter()
1690 .map(|selection| selection.head())
1691 .collect()
1692 }
1693
1694 fn handle_conversation_event(
1695 &mut self,
1696 _: ModelHandle<Conversation>,
1697 event: &ConversationEvent,
1698 cx: &mut ViewContext<Self>,
1699 ) {
1700 match event {
1701 ConversationEvent::MessagesEdited => {
1702 self.update_message_headers(cx);
1703 self.conversation.update(cx, |conversation, cx| {
1704 conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1705 });
1706 }
1707 ConversationEvent::SummaryChanged => {
1708 cx.emit(ConversationEditorEvent::TabContentChanged);
1709 self.conversation.update(cx, |conversation, cx| {
1710 conversation.save(None, self.fs.clone(), cx);
1711 });
1712 }
1713 ConversationEvent::StreamedCompletion => {
1714 self.editor.update(cx, |editor, cx| {
1715 if let Some(scroll_position) = self.scroll_position {
1716 let snapshot = editor.snapshot(cx);
1717 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1718 let scroll_top =
1719 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1720 editor.set_scroll_position(
1721 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1722 cx,
1723 );
1724 }
1725 });
1726 }
1727 }
1728 }
1729
1730 fn handle_editor_event(
1731 &mut self,
1732 _: ViewHandle<Editor>,
1733 event: &editor::Event,
1734 cx: &mut ViewContext<Self>,
1735 ) {
1736 match event {
1737 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1738 let cursor_scroll_position = self.cursor_scroll_position(cx);
1739 if *autoscroll {
1740 self.scroll_position = cursor_scroll_position;
1741 } else if self.scroll_position != cursor_scroll_position {
1742 self.scroll_position = None;
1743 }
1744 }
1745 editor::Event::SelectionsChanged { .. } => {
1746 self.scroll_position = self.cursor_scroll_position(cx);
1747 }
1748 _ => {}
1749 }
1750 }
1751
1752 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1753 self.editor.update(cx, |editor, cx| {
1754 let snapshot = editor.snapshot(cx);
1755 let cursor = editor.selections.newest_anchor().head();
1756 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1757 let scroll_position = editor
1758 .scroll_manager
1759 .anchor()
1760 .scroll_position(&snapshot.display_snapshot);
1761
1762 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1763 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1764 Some(ScrollPosition {
1765 cursor,
1766 offset_before_cursor: vec2f(
1767 scroll_position.x(),
1768 cursor_row - scroll_position.y(),
1769 ),
1770 })
1771 } else {
1772 None
1773 }
1774 })
1775 }
1776
    /// Rebuilds the header block rendered above each message.
    ///
    /// All previously inserted header blocks are removed. One sticky block, two lines
    /// tall, is then inserted above each message's start anchor. A header shows:
    /// - the sender role (clicking it cycles that message's role);
    /// - the send time, formatted with `%I:%M%P`;
    /// - an error icon whose tooltip holds the error, if the message failed.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            // The editor wraps the conversation's single buffer, so the multibuffer
            // is a singleton; its only excerpt is used to map buffer anchors.
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        // `message` is moved into the render closure below, so each
                        // header renders its own message's role, time, and status.
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            // Clicking the sender label cycles this message's role.
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                // Failed messages get an error icon with the error text
                                // as a tooltip.
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1887
1888 fn quote_selection(
1889 workspace: &mut Workspace,
1890 _: &QuoteSelection,
1891 cx: &mut ViewContext<Workspace>,
1892 ) {
1893 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1894 return;
1895 };
1896 let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1897 return;
1898 };
1899
1900 let text = editor.read_with(cx, |editor, cx| {
1901 let range = editor.selections.newest::<usize>(cx).range();
1902 let buffer = editor.buffer().read(cx).snapshot(cx);
1903 let start_language = buffer.language_at(range.start);
1904 let end_language = buffer.language_at(range.end);
1905 let language_name = if start_language == end_language {
1906 start_language.map(|language| language.name())
1907 } else {
1908 None
1909 };
1910 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1911
1912 let selected_text = buffer.text_for_range(range).collect::<String>();
1913 if selected_text.is_empty() {
1914 None
1915 } else {
1916 Some(if language_name == "markdown" {
1917 selected_text
1918 .lines()
1919 .map(|line| format!("> {}", line))
1920 .collect::<Vec<_>>()
1921 .join("\n")
1922 } else {
1923 format!("```{language_name}\n{selected_text}\n```")
1924 })
1925 }
1926 });
1927
1928 // Activate the panel
1929 if !panel.read(cx).has_focus(cx) {
1930 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1931 }
1932
1933 if let Some(text) = text {
1934 panel.update(cx, |panel, cx| {
1935 let conversation = panel
1936 .active_editor()
1937 .cloned()
1938 .unwrap_or_else(|| panel.new_conversation(cx));
1939 conversation.update(cx, |conversation, cx| {
1940 conversation
1941 .editor
1942 .update(cx, |editor, cx| editor.insert(&text, cx))
1943 });
1944 });
1945 }
1946 }
1947
1948 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1949 let editor = self.editor.read(cx);
1950 let conversation = self.conversation.read(cx);
1951 if editor.selections.count() == 1 {
1952 let selection = editor.selections.newest::<usize>(cx);
1953 let mut copied_text = String::new();
1954 let mut spanned_messages = 0;
1955 for message in conversation.messages(cx) {
1956 if message.offset_range.start >= selection.range().end {
1957 break;
1958 } else if message.offset_range.end >= selection.range().start {
1959 let range = cmp::max(message.offset_range.start, selection.range().start)
1960 ..cmp::min(message.offset_range.end, selection.range().end);
1961 if !range.is_empty() {
1962 spanned_messages += 1;
1963 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1964 for chunk in conversation.buffer.read(cx).text_for_range(range) {
1965 copied_text.push_str(&chunk);
1966 }
1967 copied_text.push('\n');
1968 }
1969 }
1970 }
1971
1972 if spanned_messages > 1 {
1973 cx.platform()
1974 .write_to_clipboard(ClipboardItem::new(copied_text));
1975 return;
1976 }
1977 }
1978
1979 cx.propagate_action();
1980 }
1981
1982 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
1983 self.conversation.update(cx, |conversation, cx| {
1984 let selections = self.editor.read(cx).selections.disjoint_anchors();
1985 for selection in selections.into_iter() {
1986 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
1987 let range = selection
1988 .map(|endpoint| endpoint.to_offset(&buffer))
1989 .range();
1990 conversation.split_message(range, cx);
1991 }
1992 });
1993 }
1994
1995 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
1996 self.conversation.update(cx, |conversation, cx| {
1997 conversation.save(None, self.fs.clone(), cx)
1998 });
1999 }
2000
2001 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2002 self.conversation.update(cx, |conversation, cx| {
2003 let new_model = match conversation.model.as_str() {
2004 "gpt-4-0613" => "gpt-3.5-turbo-0613",
2005 _ => "gpt-4-0613",
2006 };
2007 conversation.set_model(new_model.into(), cx);
2008 });
2009 }
2010
2011 fn title(&self, cx: &AppContext) -> String {
2012 self.conversation
2013 .read(cx)
2014 .summary
2015 .as_ref()
2016 .map(|summary| summary.text.clone())
2017 .unwrap_or_else(|| "New Conversation".into())
2018 }
2019
2020 fn render_current_model(
2021 &self,
2022 style: &AssistantStyle,
2023 cx: &mut ViewContext<Self>,
2024 ) -> impl Element<Self> {
2025 enum Model {}
2026
2027 MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
2028 let style = style.model.style_for(state);
2029 Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
2030 .contained()
2031 .with_style(style.container)
2032 })
2033 .with_cursor_style(CursorStyle::PointingHand)
2034 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
2035 }
2036
2037 fn render_remaining_tokens(
2038 &self,
2039 style: &AssistantStyle,
2040 cx: &mut ViewContext<Self>,
2041 ) -> Option<impl Element<Self>> {
2042 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2043 let remaining_tokens_style = if remaining_tokens <= 0 {
2044 &style.no_remaining_tokens
2045 } else {
2046 &style.remaining_tokens
2047 };
2048 Some(
2049 Label::new(
2050 remaining_tokens.to_string(),
2051 remaining_tokens_style.text.clone(),
2052 )
2053 .contained()
2054 .with_style(remaining_tokens_style.container),
2055 )
2056 }
2057}
2058
/// A `ConversationEditor` emits `ConversationEditorEvent`s.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2062
2063impl View for ConversationEditor {
2064 fn ui_name() -> &'static str {
2065 "ConversationEditor"
2066 }
2067
2068 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2069 let theme = &theme::current(cx).assistant;
2070 Stack::new()
2071 .with_child(
2072 ChildView::new(&self.editor, cx)
2073 .contained()
2074 .with_style(theme.container),
2075 )
2076 .with_child(
2077 Flex::row()
2078 .with_child(self.render_current_model(theme, cx))
2079 .with_children(self.render_remaining_tokens(theme, cx))
2080 .aligned()
2081 .top()
2082 .right(),
2083 )
2084 .into_any()
2085 }
2086
2087 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2088 if cx.is_self_focused() {
2089 cx.focus(&self.editor);
2090 }
2091 }
2092}
2093
/// Where a message begins in the conversation buffer.
///
/// The anchor becomes invalid when its text is deleted. Such anchors are folded into
/// the preceding message by `Conversation::messages`.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    start: language::Anchor,
}
2099
/// A resolved snapshot of one message, as yielded by `Conversation::messages`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets covered by the message: from its start up to the next valid
    /// message's start, or the end of the buffer.
    offset_range: Range<usize>,
    /// Indices into `Conversation::message_anchors`. `end` is the index of the last
    /// (invalidated) anchor folded into this message, so callers insert new anchors
    /// at `end + 1`.
    index_range: Range<usize>,
    id: MessageId,
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2110
2111impl Message {
2112 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2113 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2114 content.extend(buffer.text_for_range(self.offset_range.clone()));
2115 RequestMessage {
2116 role: self.role,
2117 content: content.trim_end().into(),
2118 }
2119 }
2120}
2121
2122async fn stream_completion(
2123 api_key: String,
2124 executor: Arc<Background>,
2125 mut request: OpenAIRequest,
2126) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2127 request.stream = true;
2128
2129 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2130
2131 let json_data = serde_json::to_string(&request)?;
2132 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2133 .header("Content-Type", "application/json")
2134 .header("Authorization", format!("Bearer {}", api_key))
2135 .body(json_data)?
2136 .send_async()
2137 .await?;
2138
2139 let status = response.status();
2140 if status == StatusCode::OK {
2141 executor
2142 .spawn(async move {
2143 let mut lines = BufReader::new(response.body_mut()).lines();
2144
2145 fn parse_line(
2146 line: Result<String, io::Error>,
2147 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2148 if let Some(data) = line?.strip_prefix("data: ") {
2149 let event = serde_json::from_str(&data)?;
2150 Ok(Some(event))
2151 } else {
2152 Ok(None)
2153 }
2154 }
2155
2156 while let Some(line) = lines.next().await {
2157 if let Some(event) = parse_line(line).transpose() {
2158 let done = event.as_ref().map_or(false, |event| {
2159 event
2160 .choices
2161 .last()
2162 .map_or(false, |choice| choice.finish_reason.is_some())
2163 });
2164 if tx.unbounded_send(event).is_err() {
2165 break;
2166 }
2167
2168 if done {
2169 break;
2170 }
2171 }
2172 }
2173
2174 anyhow::Ok(())
2175 })
2176 .detach();
2177
2178 Ok(rx)
2179 } else {
2180 let mut body = String::new();
2181 response.body_mut().read_to_string(&mut body).await?;
2182
2183 #[derive(Deserialize)]
2184 struct OpenAIResponse {
2185 error: OpenAIError,
2186 }
2187
2188 #[derive(Deserialize)]
2189 struct OpenAIError {
2190 message: String,
2191 }
2192
2193 match serde_json::from_str::<OpenAIResponse>(&body) {
2194 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2195 "Failed to connect to OpenAI API: {}",
2196 response.error.message,
2197 )),
2198
2199 _ => Err(anyhow!(
2200 "Failed to connect to OpenAI API: {} {}",
2201 response.status(),
2202 body,
2203 )),
2204 }
2205 }
2206}
2207
// Tests for the assistant panel and for how `Conversation` tracks message
// boundaries inside its shared buffer: insertion, splitting, offset lookup, and
// serialization round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use fs::FakeFs;
    use gpui::{AppContext, TestAppContext};
    use project::Project;

    /// Installs the test settings store and initializes theme, language, editor,
    /// workspace and project settings. It also runs this crate's own `init`, so the
    /// assistant code under test sees the globals it reads.
    fn init_test(cx: &mut TestAppContext) {
        // NOTE(review): presumably makes the test fail instead of hanging if the
        // foreground executor would park — confirm against gpui's executor docs.
        cx.foreground().forbid_parking();
        cx.update(|cx| {
            cx.set_global(SettingsStore::test(cx));
            theme::init((), cx);
            language::init(cx);
            editor::init_settings(cx);
            crate::init(cx);
            workspace::init_settings(cx);
            Project::init_settings(cx);
        });
    }

    /// Loads an `AssistantPanel` into a test workspace and opens the right dock.
    /// It then checks that dispatching `ToggleZoom` with the panel focused makes the
    /// panel the workspace's zoomed view.
    #[gpui::test]
    async fn test_panel(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.background());
        let project = Project::test(fs, [], cx).await;
        let (window_id, workspace) = cx.add_window(|cx| Workspace::test_new(project.clone(), cx));
        let weak_workspace = workspace.downgrade();

        let panel = cx
            .spawn(|cx| async move { AssistantPanel::load(weak_workspace, cx).await })
            .await
            .unwrap();

        workspace.update(cx, |workspace, cx| {
            workspace.add_panel(panel.clone(), cx);
            workspace.toggle_dock(DockPosition::Right, cx);
            assert!(workspace.right_dock().read(cx).is_open());
        });

        // Focus the panel before dispatching the action (presumably zoom targets
        // the focused item — confirm in workspace's ToggleZoom handling).
        panel.update(cx, |_, cx| cx.focus_self());
        cx.dispatch_action(window_id, workspace::ToggleZoom);
        workspace.read_with(cx, |workspace, cx| {
            assert_eq!(workspace.zoomed_view(cx).unwrap(), panel);
        })
    }

    /// Checks message ranges as messages are inserted, as the buffer is edited, and
    /// as a deletion spanning message boundaries is undone and redone.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A new conversation starts with a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Per the assertion below, inserting a message grows the preceding one by
        // one character (a separator) and starts the new message right after it.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message directly after
        // message_2, ahead of the previously inserted message_3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// Checks `split_message` at a cursor position and over a selection, for both
    /// the resulting buffer text and the resulting message ranges.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting over a non-empty selection yields two new messages: one for the
        // selected text and one for the text after it (see the ranges below).
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// Checks `messages_for_offsets`. Offsets map to the messages that contain them;
    /// offsets inside the same message yield that message once; and the end-of-buffer
    /// offset maps to the last message.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in message_1, which is reported only once.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // An empty trailing message is still found at the end-of-buffer offset.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        /// Returns the ids of the messages `messages_for_offsets` yields for `offsets`.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Checks that a conversation survives a `serialize` / `deserialize` round
    /// trip with the same buffer text and the same message ids, roles and ranges.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        // The undo below reverts this insertion: the assertions show the buffer
        // without a fourth separator and no message_3 in the list.
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    /// Snapshots the conversation's messages as `(id, role, offset_range)` tuples
    /// in the order `Conversation::messages` yields them, for easy comparison.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}