1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI REST API.
///
/// This file also passes it to the platform credential store as the key under which the
/// API credentials are written, read and deleted.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Declares the assistant's actions. `init` below filters the "assistant" namespace out of
// the command palette on the stable release channel.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
66 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
67 filter.filtered_namespaces.insert("assistant");
68 });
69 }
70
71 settings::register::<AssistantSettings>(cx);
72 cx.add_action(
73 |this: &mut AssistantPanel,
74 _: &workspace::NewFile,
75 cx: &mut ViewContext<AssistantPanel>| {
76 this.new_conversation(cx);
77 },
78 );
79 cx.add_action(ConversationEditor::assist);
80 cx.capture_action(ConversationEditor::cancel_last_assist);
81 cx.capture_action(ConversationEditor::save);
82 cx.add_action(ConversationEditor::quote_selection);
83 cx.capture_action(ConversationEditor::copy);
84 cx.add_action(ConversationEditor::split);
85 cx.capture_action(ConversationEditor::cycle_message_role);
86 cx.add_action(AssistantPanel::save_api_key);
87 cx.add_action(AssistantPanel::reset_api_key);
88 cx.add_action(AssistantPanel::toggle_zoom);
89 cx.add_action(AssistantPanel::deploy);
90 cx.add_action(AssistantPanel::select_next_match);
91 cx.add_action(AssistantPanel::select_prev_match);
92 cx.add_action(AssistantPanel::handle_editor_cancel);
93 cx.add_action(
94 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
95 workspace.toggle_panel_focus::<AssistantPanel>(cx);
96 },
97 );
98}
99
/// Events emitted by [`AssistantPanel`]; the `Panel` implementation below maps each variant
/// to one of its `should_*_on_event` / `is_focus_event` predicates.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` when the panel is not currently zoomed.
    ZoomIn,
    /// Emitted by `toggle_zoom` when the panel is currently zoomed.
    ZoomOut,
    /// Recognized by `is_focus_event`.
    Focus,
    /// Recognized by `should_close_on_event`.
    Close,
    /// Emitted when the dock position read from settings differs from the last observed one.
    DockPositionChanged,
}
108
/// The assistant dock panel: a set of conversation editors plus a list of saved conversations.
pub struct AssistantPanel {
    /// Weak handle to the workspace this panel was loaded for.
    workspace: WeakViewHandle<Workspace>,
    /// Size used when docked left or right; `None` falls back to the settings default.
    width: Option<f32>,
    /// Size used when docked at the bottom; `None` falls back to the settings default.
    height: Option<f32>,
    /// Index into `editors` of the conversation being shown; `None` shows the saved list.
    active_editor_index: Option<usize>,
    /// Value `active_editor_index` had before its most recent change.
    prev_active_editor_index: Option<usize>,
    /// Conversation editors opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of saved conversations, re-listed by `_watch_saved_conversations`.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// List state handed to the `UniformList` that renders `saved_conversations`.
    saved_conversations_list_state: UniformListState,
    /// Whether the panel is zoomed (set through `Panel::set_zoomed`).
    zoomed: bool,
    /// Tracks focus; updated in `focus_in` / `focus_out`.
    has_focus: bool,
    /// Toolbar hosting the buffer search bar.
    toolbar: ViewHandle<Toolbar>,
    /// API key cell; clones of this `Rc` are handed to conversations and editors.
    api_key: Rc<RefCell<Option<String>>>,
    /// Single-line editor shown while asking the user for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once `set_active` has attempted to read the key from the environment/credentials.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Subscriptions held for the lifetime of the panel.
    subscriptions: Vec<Subscription>,
    /// Task that watches `CONVERSATIONS_DIR` and refreshes `saved_conversations`.
    _watch_saved_conversations: Task<Result<()>>,
}
129
130impl AssistantPanel {
131 pub fn load(
132 workspace: WeakViewHandle<Workspace>,
133 cx: AsyncAppContext,
134 ) -> Task<Result<ViewHandle<Self>>> {
135 cx.spawn(|mut cx| async move {
136 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
137 let saved_conversations = SavedConversationMetadata::list(fs.clone())
138 .await
139 .log_err()
140 .unwrap_or_default();
141
142 // TODO: deserialize state.
143 let workspace_handle = workspace.clone();
144 workspace.update(&mut cx, |workspace, cx| {
145 cx.add_view::<Self, _>(|cx| {
146 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
147 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
148 let mut events = fs
149 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
150 .await;
151 while events.next().await.is_some() {
152 let saved_conversations = SavedConversationMetadata::list(fs.clone())
153 .await
154 .log_err()
155 .unwrap_or_default();
156 this.update(&mut cx, |this, _| {
157 this.saved_conversations = saved_conversations
158 })
159 .ok();
160 }
161
162 anyhow::Ok(())
163 });
164
165 let toolbar = cx.add_view(|cx| {
166 let mut toolbar = Toolbar::new(None);
167 toolbar.set_can_navigate(false, cx);
168 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
169 toolbar
170 });
171 let mut this = Self {
172 workspace: workspace_handle,
173 active_editor_index: Default::default(),
174 prev_active_editor_index: Default::default(),
175 editors: Default::default(),
176 saved_conversations,
177 saved_conversations_list_state: Default::default(),
178 zoomed: false,
179 has_focus: false,
180 toolbar,
181 api_key: Rc::new(RefCell::new(None)),
182 api_key_editor: None,
183 has_read_credentials: false,
184 languages: workspace.app_state().languages.clone(),
185 fs: workspace.app_state().fs.clone(),
186 width: None,
187 height: None,
188 subscriptions: Default::default(),
189 _watch_saved_conversations,
190 };
191
192 let mut old_dock_position = this.position(cx);
193 this.subscriptions =
194 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
195 let new_dock_position = this.position(cx);
196 if new_dock_position != old_dock_position {
197 old_dock_position = new_dock_position;
198 cx.emit(AssistantPanelEvent::DockPositionChanged);
199 }
200 })];
201
202 this
203 })
204 })
205 })
206 }
207
208 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
209 let editor = cx.add_view(|cx| {
210 ConversationEditor::new(
211 self.api_key.clone(),
212 self.languages.clone(),
213 self.fs.clone(),
214 cx,
215 )
216 });
217 self.add_conversation(editor.clone(), cx);
218 editor
219 }
220
221 fn add_conversation(
222 &mut self,
223 editor: ViewHandle<ConversationEditor>,
224 cx: &mut ViewContext<Self>,
225 ) {
226 self.subscriptions
227 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
228
229 let conversation = editor.read(cx).conversation.clone();
230 self.subscriptions
231 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
232
233 let index = self.editors.len();
234 self.editors.push(editor);
235 self.set_active_editor_index(Some(index), cx);
236 }
237
238 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
239 self.prev_active_editor_index = self.active_editor_index;
240 self.active_editor_index = index;
241 if let Some(editor) = self.active_editor() {
242 let editor = editor.read(cx).editor.clone();
243 self.toolbar.update(cx, |toolbar, cx| {
244 toolbar.set_active_item(Some(&editor), cx);
245 });
246 if self.has_focus(cx) {
247 cx.focus(&editor);
248 }
249 } else {
250 self.toolbar.update(cx, |toolbar, cx| {
251 toolbar.set_active_item(None, cx);
252 });
253 }
254
255 cx.notify();
256 }
257
258 fn handle_conversation_editor_event(
259 &mut self,
260 _: ViewHandle<ConversationEditor>,
261 event: &ConversationEditorEvent,
262 cx: &mut ViewContext<Self>,
263 ) {
264 match event {
265 ConversationEditorEvent::TabContentChanged => cx.notify(),
266 }
267 }
268
269 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
270 if let Some(api_key) = self
271 .api_key_editor
272 .as_ref()
273 .map(|editor| editor.read(cx).text(cx))
274 {
275 if !api_key.is_empty() {
276 cx.platform()
277 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
278 .log_err();
279 *self.api_key.borrow_mut() = Some(api_key);
280 self.api_key_editor.take();
281 cx.focus_self();
282 cx.notify();
283 }
284 } else {
285 cx.propagate_action();
286 }
287 }
288
289 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
290 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
291 self.api_key.take();
292 self.api_key_editor = Some(build_api_key_editor(cx));
293 cx.focus_self();
294 cx.notify();
295 }
296
297 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
298 if self.zoomed {
299 cx.emit(AssistantPanelEvent::ZoomOut)
300 } else {
301 cx.emit(AssistantPanelEvent::ZoomIn)
302 }
303 }
304
305 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
306 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
307 if search_bar.update(cx, |search_bar, cx| search_bar.show(action.focus, true, cx)) {
308 return;
309 }
310 }
311 cx.propagate_action();
312 }
313
314 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
315 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
316 if !search_bar.read(cx).is_dismissed() {
317 search_bar.update(cx, |search_bar, cx| {
318 search_bar.dismiss(&Default::default(), cx)
319 });
320 return;
321 }
322 }
323 cx.propagate_action();
324 }
325
326 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
327 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
328 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, cx));
329 }
330 }
331
332 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
333 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
334 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, cx));
335 }
336 }
337
338 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
339 self.editors.get(self.active_editor_index?)
340 }
341
342 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
343 enum History {}
344 let theme = theme::current(cx);
345 let tooltip_style = theme::current(cx).tooltip.clone();
346 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
347 let style = theme.assistant.hamburger_button.style_for(state);
348 Svg::for_style(style.icon.clone())
349 .contained()
350 .with_style(style.container)
351 })
352 .with_cursor_style(CursorStyle::PointingHand)
353 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
354 if this.active_editor().is_some() {
355 this.set_active_editor_index(None, cx);
356 } else {
357 this.set_active_editor_index(this.prev_active_editor_index, cx);
358 }
359 })
360 .with_tooltip::<History>(1, "History".into(), None, tooltip_style, cx)
361 }
362
363 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
364 if self.active_editor().is_some() {
365 vec![
366 Self::render_split_button(cx).into_any(),
367 Self::render_quote_button(cx).into_any(),
368 Self::render_assist_button(cx).into_any(),
369 ]
370 } else {
371 Default::default()
372 }
373 }
374
375 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
376 let theme = theme::current(cx);
377 let tooltip_style = theme::current(cx).tooltip.clone();
378 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
379 let style = theme.assistant.split_button.style_for(state);
380 Svg::for_style(style.icon.clone())
381 .contained()
382 .with_style(style.container)
383 })
384 .with_cursor_style(CursorStyle::PointingHand)
385 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
386 if let Some(active_editor) = this.active_editor() {
387 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
388 }
389 })
390 .with_tooltip::<Split>(
391 1,
392 "Split Message".into(),
393 Some(Box::new(Split)),
394 tooltip_style,
395 cx,
396 )
397 }
398
399 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
400 let theme = theme::current(cx);
401 let tooltip_style = theme::current(cx).tooltip.clone();
402 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
403 let style = theme.assistant.assist_button.style_for(state);
404 Svg::for_style(style.icon.clone())
405 .contained()
406 .with_style(style.container)
407 })
408 .with_cursor_style(CursorStyle::PointingHand)
409 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
410 if let Some(active_editor) = this.active_editor() {
411 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
412 }
413 })
414 .with_tooltip::<Assist>(
415 1,
416 "Assist".into(),
417 Some(Box::new(Assist)),
418 tooltip_style,
419 cx,
420 )
421 }
422
423 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
424 let theme = theme::current(cx);
425 let tooltip_style = theme::current(cx).tooltip.clone();
426 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
427 let style = theme.assistant.quote_button.style_for(state);
428 Svg::for_style(style.icon.clone())
429 .contained()
430 .with_style(style.container)
431 })
432 .with_cursor_style(CursorStyle::PointingHand)
433 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
434 if let Some(workspace) = this.workspace.upgrade(cx) {
435 cx.window_context().defer(move |cx| {
436 workspace.update(cx, |workspace, cx| {
437 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
438 });
439 });
440 }
441 })
442 .with_tooltip::<QuoteSelection>(
443 1,
444 "Quote Selection".into(),
445 Some(Box::new(QuoteSelection)),
446 tooltip_style,
447 cx,
448 )
449 }
450
451 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
452 let theme = theme::current(cx);
453 let tooltip_style = theme::current(cx).tooltip.clone();
454 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
455 let style = theme.assistant.plus_button.style_for(state);
456 Svg::for_style(style.icon.clone())
457 .contained()
458 .with_style(style.container)
459 })
460 .with_cursor_style(CursorStyle::PointingHand)
461 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
462 this.new_conversation(cx);
463 })
464 .with_tooltip::<NewConversation>(
465 1,
466 "New Conversation".into(),
467 Some(Box::new(NewConversation)),
468 tooltip_style,
469 cx,
470 )
471 }
472
473 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
474 enum ToggleZoomButton {}
475
476 let theme = theme::current(cx);
477 let tooltip_style = theme::current(cx).tooltip.clone();
478 let style = if self.zoomed {
479 &theme.assistant.zoom_out_button
480 } else {
481 &theme.assistant.zoom_in_button
482 };
483
484 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
485 let style = style.style_for(state);
486 Svg::for_style(style.icon.clone())
487 .contained()
488 .with_style(style.container)
489 })
490 .with_cursor_style(CursorStyle::PointingHand)
491 .on_click(MouseButton::Left, |_, this, cx| {
492 this.toggle_zoom(&ToggleZoom, cx);
493 })
494 .with_tooltip::<ToggleZoom>(
495 0,
496 if self.zoomed {
497 "Zoom Out".into()
498 } else {
499 "Zoom In".into()
500 },
501 Some(Box::new(ToggleZoom)),
502 tooltip_style,
503 cx,
504 )
505 }
506
507 fn render_saved_conversation(
508 &mut self,
509 index: usize,
510 cx: &mut ViewContext<Self>,
511 ) -> impl Element<Self> {
512 let conversation = &self.saved_conversations[index];
513 let path = conversation.path.clone();
514 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
515 let style = &theme::current(cx).assistant.saved_conversation;
516 Flex::row()
517 .with_child(
518 Label::new(
519 conversation.mtime.format("%F %I:%M%p").to_string(),
520 style.saved_at.text.clone(),
521 )
522 .aligned()
523 .contained()
524 .with_style(style.saved_at.container),
525 )
526 .with_child(
527 Label::new(conversation.title.clone(), style.title.text.clone())
528 .aligned()
529 .contained()
530 .with_style(style.title.container),
531 )
532 .contained()
533 .with_style(*style.container.style_for(state))
534 })
535 .with_cursor_style(CursorStyle::PointingHand)
536 .on_click(MouseButton::Left, move |_, this, cx| {
537 this.open_conversation(path.clone(), cx)
538 .detach_and_log_err(cx)
539 })
540 }
541
542 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
543 if let Some(ix) = self.editor_index_for_path(&path, cx) {
544 self.set_active_editor_index(Some(ix), cx);
545 return Task::ready(Ok(()));
546 }
547
548 let fs = self.fs.clone();
549 let api_key = self.api_key.clone();
550 let languages = self.languages.clone();
551 cx.spawn(|this, mut cx| async move {
552 let saved_conversation = fs.load(&path).await?;
553 let saved_conversation = serde_json::from_str(&saved_conversation)?;
554 let conversation = cx.add_model(|cx| {
555 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
556 });
557 this.update(&mut cx, |this, cx| {
558 // If, by the time we've loaded the conversation, the user has already opened
559 // the same conversation, we don't want to open it again.
560 if let Some(ix) = this.editor_index_for_path(&path, cx) {
561 this.set_active_editor_index(Some(ix), cx);
562 } else {
563 let editor = cx
564 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
565 this.add_conversation(editor, cx);
566 }
567 })?;
568 Ok(())
569 })
570 }
571
572 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
573 self.editors
574 .iter()
575 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
576 }
577}
578
579fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
580 cx.add_view(|cx| {
581 let mut editor = Editor::single_line(
582 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
583 cx,
584 );
585 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
586 editor
587 })
588}
589
// The panel emits `AssistantPanelEvent`s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
593
594impl View for AssistantPanel {
595 fn ui_name() -> &'static str {
596 "AssistantPanel"
597 }
598
599 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
600 let theme = &theme::current(cx);
601 let style = &theme.assistant;
602 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
603 Flex::column()
604 .with_child(
605 Text::new(
606 "Paste your OpenAI API key and press Enter to use the assistant",
607 style.api_key_prompt.text.clone(),
608 )
609 .aligned(),
610 )
611 .with_child(
612 ChildView::new(api_key_editor, cx)
613 .contained()
614 .with_style(style.api_key_editor.container)
615 .aligned(),
616 )
617 .contained()
618 .with_style(style.api_key_prompt.container)
619 .aligned()
620 .into_any()
621 } else {
622 let title = self.active_editor().map(|editor| {
623 Label::new(editor.read(cx).title(cx), style.title.text.clone())
624 .contained()
625 .with_style(style.title.container)
626 .aligned()
627 .left()
628 .flex(1., false)
629 });
630 let mut header = Flex::row()
631 .with_child(Self::render_hamburger_button(cx).aligned())
632 .with_children(title);
633 if self.has_focus {
634 header.add_children(
635 self.render_editor_tools(cx)
636 .into_iter()
637 .map(|tool| tool.aligned().flex_float()),
638 );
639 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
640 header.add_child(self.render_zoom_button(cx).aligned());
641 }
642
643 Flex::column()
644 .with_child(
645 header
646 .contained()
647 .with_style(theme.workspace.tab_bar.container)
648 .expanded()
649 .constrained()
650 .with_height(theme.workspace.tab_bar.height),
651 )
652 .with_children(if self.toolbar.read(cx).hidden() {
653 None
654 } else {
655 Some(ChildView::new(&self.toolbar, cx).expanded())
656 })
657 .with_child(if let Some(editor) = self.active_editor() {
658 ChildView::new(editor, cx).flex(1., true).into_any()
659 } else {
660 UniformList::new(
661 self.saved_conversations_list_state.clone(),
662 self.saved_conversations.len(),
663 cx,
664 |this, range, items, cx| {
665 for ix in range {
666 items.push(this.render_saved_conversation(ix, cx).into_any());
667 }
668 },
669 )
670 .flex(1., true)
671 .into_any()
672 })
673 .into_any()
674 }
675 }
676
677 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
678 self.has_focus = true;
679 self.toolbar
680 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
681 cx.notify();
682 if cx.is_self_focused() {
683 if let Some(editor) = self.active_editor() {
684 cx.focus(editor);
685 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
686 cx.focus(api_key_editor);
687 }
688 }
689 }
690
691 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
692 self.has_focus = false;
693 self.toolbar
694 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
695 cx.notify();
696 }
697}
698
699impl Panel for AssistantPanel {
700 fn position(&self, cx: &WindowContext) -> DockPosition {
701 match settings::get::<AssistantSettings>(cx).dock {
702 AssistantDockPosition::Left => DockPosition::Left,
703 AssistantDockPosition::Bottom => DockPosition::Bottom,
704 AssistantDockPosition::Right => DockPosition::Right,
705 }
706 }
707
708 fn position_is_valid(&self, _: DockPosition) -> bool {
709 true
710 }
711
712 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
713 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
714 let dock = match position {
715 DockPosition::Left => AssistantDockPosition::Left,
716 DockPosition::Bottom => AssistantDockPosition::Bottom,
717 DockPosition::Right => AssistantDockPosition::Right,
718 };
719 settings.dock = Some(dock);
720 });
721 }
722
723 fn size(&self, cx: &WindowContext) -> f32 {
724 let settings = settings::get::<AssistantSettings>(cx);
725 match self.position(cx) {
726 DockPosition::Left | DockPosition::Right => {
727 self.width.unwrap_or_else(|| settings.default_width)
728 }
729 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
730 }
731 }
732
733 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
734 match self.position(cx) {
735 DockPosition::Left | DockPosition::Right => self.width = Some(size),
736 DockPosition::Bottom => self.height = Some(size),
737 }
738 cx.notify();
739 }
740
741 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
742 matches!(event, AssistantPanelEvent::ZoomIn)
743 }
744
745 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
746 matches!(event, AssistantPanelEvent::ZoomOut)
747 }
748
749 fn is_zoomed(&self, _: &WindowContext) -> bool {
750 self.zoomed
751 }
752
753 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
754 self.zoomed = zoomed;
755 cx.notify();
756 }
757
758 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
759 if active {
760 if self.api_key.borrow().is_none() && !self.has_read_credentials {
761 self.has_read_credentials = true;
762 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
763 Some(api_key)
764 } else if let Some((_, api_key)) = cx
765 .platform()
766 .read_credentials(OPENAI_API_URL)
767 .log_err()
768 .flatten()
769 {
770 String::from_utf8(api_key).log_err()
771 } else {
772 None
773 };
774 if let Some(api_key) = api_key {
775 *self.api_key.borrow_mut() = Some(api_key);
776 } else if self.api_key_editor.is_none() {
777 self.api_key_editor = Some(build_api_key_editor(cx));
778 cx.notify();
779 }
780 }
781
782 if self.editors.is_empty() {
783 self.new_conversation(cx);
784 }
785 }
786 }
787
788 fn icon_path(&self) -> &'static str {
789 "icons/robot_14.svg"
790 }
791
792 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
793 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
794 }
795
796 fn should_change_position_on_event(event: &Self::Event) -> bool {
797 matches!(event, AssistantPanelEvent::DockPositionChanged)
798 }
799
800 fn should_activate_on_event(_: &Self::Event) -> bool {
801 false
802 }
803
804 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
805 matches!(event, AssistantPanelEvent::Close)
806 }
807
808 fn has_focus(&self, _: &WindowContext) -> bool {
809 self.has_focus
810 }
811
812 fn is_focus_event(event: &Self::Event) -> bool {
813 matches!(event, AssistantPanelEvent::Focus)
814 }
815}
816
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// Emitted when the conversation's buffer reports an `Edited` event.
    MessagesEdited,
    // NOTE(review): presumably emitted when the summary text changes — the emitting code is
    // not in this part of the file; confirm.
    SummaryChanged,
    // NOTE(review): presumably emitted as completion chunks stream in — confirm against
    // the emitting code.
    StreamedCompletion,
}
822
/// Summary text of a conversation.
#[derive(Default)]
struct Summary {
    /// The summary text; this is what gets written out by `Conversation::serialize`.
    text: String,
    /// Set to `true` for summaries restored by `Conversation::deserialize`.
    // NOTE(review): presumably `false` while a summary is still being generated — confirm.
    done: bool,
}
828
/// Model of a single assistant conversation backed by a text buffer.
struct Conversation {
    /// Buffer holding the text of all messages.
    buffer: ModelHandle<Buffer>,
    /// Start anchor and id of each message in `buffer`.
    message_anchors: Vec<MessageAnchor>,
    /// Role, timestamp and status of each message, keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id that will be given to the next message created.
    next_message_id: MessageId,
    /// Conversation summary, if any.
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    /// Name of the model; used for token counting and for completion requests.
    model: String,
    /// Most recently computed token count of the conversation; `None` until first computed.
    token_count: Option<usize>,
    /// Context size of `model` as reported by `tiktoken_rs`.
    max_token_count: usize,
    /// Task computing `token_count`; replaced on every recount.
    pending_token_count: Task<Option<()>>,
    /// API key cell shared with the panel.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    /// File the conversation was loaded from; `None` for conversations created with `new`.
    path: Option<PathBuf>,
    /// Keeps the buffer event subscription alive.
    _subscriptions: Vec<Subscription>,
}
847
// A conversation emits `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
851
852impl Conversation {
853 fn new(
854 api_key: Rc<RefCell<Option<String>>>,
855 language_registry: Arc<LanguageRegistry>,
856 cx: &mut ModelContext<Self>,
857 ) -> Self {
858 let model = "gpt-3.5-turbo-0613";
859 let markdown = language_registry.language_for_name("Markdown");
860 let buffer = cx.add_model(|cx| {
861 let mut buffer = Buffer::new(0, "", cx);
862 buffer.set_language_registry(language_registry);
863 cx.spawn_weak(|buffer, mut cx| async move {
864 let markdown = markdown.await?;
865 let buffer = buffer
866 .upgrade(&cx)
867 .ok_or_else(|| anyhow!("buffer was dropped"))?;
868 buffer.update(&mut cx, |buffer, cx| {
869 buffer.set_language(Some(markdown), cx)
870 });
871 anyhow::Ok(())
872 })
873 .detach_and_log_err(cx);
874 buffer
875 });
876
877 let mut this = Self {
878 message_anchors: Default::default(),
879 messages_metadata: Default::default(),
880 next_message_id: Default::default(),
881 summary: None,
882 pending_summary: Task::ready(None),
883 completion_count: Default::default(),
884 pending_completions: Default::default(),
885 token_count: None,
886 max_token_count: tiktoken_rs::model::get_context_size(model),
887 pending_token_count: Task::ready(None),
888 model: model.into(),
889 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
890 pending_save: Task::ready(Ok(())),
891 path: None,
892 api_key,
893 buffer,
894 };
895 let message = MessageAnchor {
896 id: MessageId(post_inc(&mut this.next_message_id.0)),
897 start: language::Anchor::MIN,
898 };
899 this.message_anchors.push(message.clone());
900 this.messages_metadata.insert(
901 message.id,
902 MessageMetadata {
903 role: Role::User,
904 sent_at: Local::now(),
905 status: MessageStatus::Done,
906 },
907 );
908
909 this.count_remaining_tokens(cx);
910 this
911 }
912
913 fn serialize(&self, cx: &AppContext) -> SavedConversation {
914 SavedConversation {
915 zed: "conversation".into(),
916 version: SavedConversation::VERSION.into(),
917 text: self.buffer.read(cx).text(),
918 message_metadata: self.messages_metadata.clone(),
919 messages: self
920 .messages(cx)
921 .map(|message| SavedMessage {
922 id: message.id,
923 start: message.offset_range.start,
924 })
925 .collect(),
926 summary: self
927 .summary
928 .as_ref()
929 .map(|summary| summary.text.clone())
930 .unwrap_or_default(),
931 model: self.model.clone(),
932 }
933 }
934
935 fn deserialize(
936 saved_conversation: SavedConversation,
937 path: PathBuf,
938 api_key: Rc<RefCell<Option<String>>>,
939 language_registry: Arc<LanguageRegistry>,
940 cx: &mut ModelContext<Self>,
941 ) -> Self {
942 let model = saved_conversation.model;
943 let markdown = language_registry.language_for_name("Markdown");
944 let mut message_anchors = Vec::new();
945 let mut next_message_id = MessageId(0);
946 let buffer = cx.add_model(|cx| {
947 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
948 for message in saved_conversation.messages {
949 message_anchors.push(MessageAnchor {
950 id: message.id,
951 start: buffer.anchor_before(message.start),
952 });
953 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
954 }
955 buffer.set_language_registry(language_registry);
956 cx.spawn_weak(|buffer, mut cx| async move {
957 let markdown = markdown.await?;
958 let buffer = buffer
959 .upgrade(&cx)
960 .ok_or_else(|| anyhow!("buffer was dropped"))?;
961 buffer.update(&mut cx, |buffer, cx| {
962 buffer.set_language(Some(markdown), cx)
963 });
964 anyhow::Ok(())
965 })
966 .detach_and_log_err(cx);
967 buffer
968 });
969
970 let mut this = Self {
971 message_anchors,
972 messages_metadata: saved_conversation.message_metadata,
973 next_message_id,
974 summary: Some(Summary {
975 text: saved_conversation.summary,
976 done: true,
977 }),
978 pending_summary: Task::ready(None),
979 completion_count: Default::default(),
980 pending_completions: Default::default(),
981 token_count: None,
982 max_token_count: tiktoken_rs::model::get_context_size(&model),
983 pending_token_count: Task::ready(None),
984 model,
985 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
986 pending_save: Task::ready(Ok(())),
987 path: Some(path),
988 api_key,
989 buffer,
990 };
991 this.count_remaining_tokens(cx);
992 this
993 }
994
995 fn handle_buffer_event(
996 &mut self,
997 _: ModelHandle<Buffer>,
998 event: &language::Event,
999 cx: &mut ModelContext<Self>,
1000 ) {
1001 match event {
1002 language::Event::Edited => {
1003 self.count_remaining_tokens(cx);
1004 cx.emit(ConversationEvent::MessagesEdited);
1005 }
1006 _ => {}
1007 }
1008 }
1009
1010 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1011 let messages = self
1012 .messages(cx)
1013 .into_iter()
1014 .filter_map(|message| {
1015 Some(tiktoken_rs::ChatCompletionRequestMessage {
1016 role: match message.role {
1017 Role::User => "user".into(),
1018 Role::Assistant => "assistant".into(),
1019 Role::System => "system".into(),
1020 },
1021 content: self
1022 .buffer
1023 .read(cx)
1024 .text_for_range(message.offset_range)
1025 .collect(),
1026 name: None,
1027 })
1028 })
1029 .collect::<Vec<_>>();
1030 let model = self.model.clone();
1031 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1032 async move {
1033 cx.background().timer(Duration::from_millis(200)).await;
1034 let token_count = cx
1035 .background()
1036 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1037 .await?;
1038
1039 this.upgrade(&cx)
1040 .ok_or_else(|| anyhow!("conversation was dropped"))?
1041 .update(&mut cx, |this, cx| {
1042 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1043 this.token_count = Some(token_count);
1044 cx.notify()
1045 });
1046 anyhow::Ok(())
1047 }
1048 .log_err()
1049 });
1050 }
1051
1052 fn remaining_tokens(&self) -> Option<isize> {
1053 Some(self.max_token_count as isize - self.token_count? as isize)
1054 }
1055
    /// Switches the conversation to `model`, schedules a token recount for it and
    /// notifies observers.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
1061
1062 fn assist(
1063 &mut self,
1064 selected_messages: HashSet<MessageId>,
1065 cx: &mut ModelContext<Self>,
1066 ) -> Vec<MessageAnchor> {
1067 let mut user_messages = Vec::new();
1068 let mut tasks = Vec::new();
1069
1070 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1071 message
1072 .start
1073 .is_valid(self.buffer.read(cx))
1074 .then_some(message.id)
1075 });
1076
1077 for selected_message_id in selected_messages {
1078 let selected_message_role =
1079 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1080 metadata.role
1081 } else {
1082 continue;
1083 };
1084
1085 if selected_message_role == Role::Assistant {
1086 if let Some(user_message) = self.insert_message_after(
1087 selected_message_id,
1088 Role::User,
1089 MessageStatus::Done,
1090 cx,
1091 ) {
1092 user_messages.push(user_message);
1093 } else {
1094 continue;
1095 }
1096 } else {
1097 let request = OpenAIRequest {
1098 model: self.model.clone(),
1099 messages: self
1100 .messages(cx)
1101 .filter(|message| matches!(message.status, MessageStatus::Done))
1102 .flat_map(|message| {
1103 let mut system_message = None;
1104 if message.id == selected_message_id {
1105 system_message = Some(RequestMessage {
1106 role: Role::System,
1107 content: concat!(
1108 "Treat the following messages as additional knowledge you have learned about, ",
1109 "but act as if they were not part of this conversation. That is, treat them ",
1110 "as if the user didn't see them and couldn't possibly inquire about them."
1111 ).into()
1112 });
1113 }
1114
1115 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1116 })
1117 .chain(Some(RequestMessage {
1118 role: Role::System,
1119 content: format!(
1120 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1121 selected_message_id.0
1122 ),
1123 }))
1124 .collect(),
1125 stream: true,
1126 };
1127
1128 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1129 let stream = stream_completion(api_key, cx.background().clone(), request);
1130 let assistant_message = self
1131 .insert_message_after(
1132 selected_message_id,
1133 Role::Assistant,
1134 MessageStatus::Pending,
1135 cx,
1136 )
1137 .unwrap();
1138
1139 // Queue up the user's next reply
1140 if Some(selected_message_id) == last_message_id {
1141 let user_message = self
1142 .insert_message_after(
1143 assistant_message.id,
1144 Role::User,
1145 MessageStatus::Done,
1146 cx,
1147 )
1148 .unwrap();
1149 user_messages.push(user_message);
1150 }
1151
1152 tasks.push(cx.spawn_weak({
1153 |this, mut cx| async move {
1154 let assistant_message_id = assistant_message.id;
1155 let stream_completion = async {
1156 let mut messages = stream.await?;
1157
1158 while let Some(message) = messages.next().await {
1159 let mut message = message?;
1160 if let Some(choice) = message.choices.pop() {
1161 this.upgrade(&cx)
1162 .ok_or_else(|| anyhow!("conversation was dropped"))?
1163 .update(&mut cx, |this, cx| {
1164 let text: Arc<str> = choice.delta.content?.into();
1165 let message_ix = this.message_anchors.iter().position(
1166 |message| message.id == assistant_message_id,
1167 )?;
1168 this.buffer.update(cx, |buffer, cx| {
1169 let offset = this.message_anchors[message_ix + 1..]
1170 .iter()
1171 .find(|message| message.start.is_valid(buffer))
1172 .map_or(buffer.len(), |message| {
1173 message
1174 .start
1175 .to_offset(buffer)
1176 .saturating_sub(1)
1177 });
1178 buffer.edit([(offset..offset, text)], None, cx);
1179 });
1180 cx.emit(ConversationEvent::StreamedCompletion);
1181
1182 Some(())
1183 });
1184 }
1185 smol::future::yield_now().await;
1186 }
1187
1188 this.upgrade(&cx)
1189 .ok_or_else(|| anyhow!("conversation was dropped"))?
1190 .update(&mut cx, |this, cx| {
1191 this.pending_completions.retain(|completion| {
1192 completion.id != this.completion_count
1193 });
1194 this.summarize(cx);
1195 });
1196
1197 anyhow::Ok(())
1198 };
1199
1200 let result = stream_completion.await;
1201 if let Some(this) = this.upgrade(&cx) {
1202 this.update(&mut cx, |this, cx| {
1203 if let Some(metadata) =
1204 this.messages_metadata.get_mut(&assistant_message.id)
1205 {
1206 match result {
1207 Ok(_) => {
1208 metadata.status = MessageStatus::Done;
1209 }
1210 Err(error) => {
1211 metadata.status = MessageStatus::Error(
1212 error.to_string().trim().into(),
1213 );
1214 }
1215 }
1216 cx.notify();
1217 }
1218 });
1219 }
1220 }
1221 }));
1222 }
1223 }
1224
1225 if !tasks.is_empty() {
1226 self.pending_completions.push(PendingCompletion {
1227 id: post_inc(&mut self.completion_count),
1228 _tasks: tasks,
1229 });
1230 }
1231
1232 user_messages
1233 }
1234
1235 fn cancel_last_assist(&mut self) -> bool {
1236 self.pending_completions.pop().is_some()
1237 }
1238
1239 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1240 for id in ids {
1241 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1242 metadata.role.cycle();
1243 cx.emit(ConversationEvent::MessagesEdited);
1244 cx.notify();
1245 }
1246 }
1247 }
1248
1249 fn insert_message_after(
1250 &mut self,
1251 message_id: MessageId,
1252 role: Role,
1253 status: MessageStatus,
1254 cx: &mut ModelContext<Self>,
1255 ) -> Option<MessageAnchor> {
1256 if let Some(prev_message_ix) = self
1257 .message_anchors
1258 .iter()
1259 .position(|message| message.id == message_id)
1260 {
1261 // Find the next valid message after the one we were given.
1262 let mut next_message_ix = prev_message_ix + 1;
1263 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1264 if next_message.start.is_valid(self.buffer.read(cx)) {
1265 break;
1266 }
1267 next_message_ix += 1;
1268 }
1269
1270 let start = self.buffer.update(cx, |buffer, cx| {
1271 let offset = self
1272 .message_anchors
1273 .get(next_message_ix)
1274 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1275 buffer.edit([(offset..offset, "\n")], None, cx);
1276 buffer.anchor_before(offset + 1)
1277 });
1278 let message = MessageAnchor {
1279 id: MessageId(post_inc(&mut self.next_message_id.0)),
1280 start,
1281 };
1282 self.message_anchors
1283 .insert(next_message_ix, message.clone());
1284 self.messages_metadata.insert(
1285 message.id,
1286 MessageMetadata {
1287 role,
1288 sent_at: Local::now(),
1289 status,
1290 },
1291 );
1292 cx.emit(ConversationEvent::MessagesEdited);
1293 Some(message)
1294 } else {
1295 None
1296 }
1297 }
1298
1299 fn split_message(
1300 &mut self,
1301 range: Range<usize>,
1302 cx: &mut ModelContext<Self>,
1303 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1304 let start_message = self.message_for_offset(range.start, cx);
1305 let end_message = self.message_for_offset(range.end, cx);
1306 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1307 // Prevent splitting when range spans multiple messages.
1308 if start_message.id != end_message.id {
1309 return (None, None);
1310 }
1311
1312 let message = start_message;
1313 let role = message.role;
1314 let mut edited_buffer = false;
1315
1316 let mut suffix_start = None;
1317 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1318 {
1319 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1320 suffix_start = Some(range.end + 1);
1321 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1322 suffix_start = Some(range.end);
1323 }
1324 }
1325
1326 let suffix = if let Some(suffix_start) = suffix_start {
1327 MessageAnchor {
1328 id: MessageId(post_inc(&mut self.next_message_id.0)),
1329 start: self.buffer.read(cx).anchor_before(suffix_start),
1330 }
1331 } else {
1332 self.buffer.update(cx, |buffer, cx| {
1333 buffer.edit([(range.end..range.end, "\n")], None, cx);
1334 });
1335 edited_buffer = true;
1336 MessageAnchor {
1337 id: MessageId(post_inc(&mut self.next_message_id.0)),
1338 start: self.buffer.read(cx).anchor_before(range.end + 1),
1339 }
1340 };
1341
1342 self.message_anchors
1343 .insert(message.index_range.end + 1, suffix.clone());
1344 self.messages_metadata.insert(
1345 suffix.id,
1346 MessageMetadata {
1347 role,
1348 sent_at: Local::now(),
1349 status: MessageStatus::Done,
1350 },
1351 );
1352
1353 let new_messages =
1354 if range.start == range.end || range.start == message.offset_range.start {
1355 (None, Some(suffix))
1356 } else {
1357 let mut prefix_end = None;
1358 if range.start > message.offset_range.start
1359 && range.end < message.offset_range.end - 1
1360 {
1361 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1362 prefix_end = Some(range.start + 1);
1363 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1364 == Some('\n')
1365 {
1366 prefix_end = Some(range.start);
1367 }
1368 }
1369
1370 let selection = if let Some(prefix_end) = prefix_end {
1371 cx.emit(ConversationEvent::MessagesEdited);
1372 MessageAnchor {
1373 id: MessageId(post_inc(&mut self.next_message_id.0)),
1374 start: self.buffer.read(cx).anchor_before(prefix_end),
1375 }
1376 } else {
1377 self.buffer.update(cx, |buffer, cx| {
1378 buffer.edit([(range.start..range.start, "\n")], None, cx)
1379 });
1380 edited_buffer = true;
1381 MessageAnchor {
1382 id: MessageId(post_inc(&mut self.next_message_id.0)),
1383 start: self.buffer.read(cx).anchor_before(range.end + 1),
1384 }
1385 };
1386
1387 self.message_anchors
1388 .insert(message.index_range.end + 1, selection.clone());
1389 self.messages_metadata.insert(
1390 selection.id,
1391 MessageMetadata {
1392 role,
1393 sent_at: Local::now(),
1394 status: MessageStatus::Done,
1395 },
1396 );
1397 (Some(selection), Some(suffix))
1398 };
1399
1400 if !edited_buffer {
1401 cx.emit(ConversationEvent::MessagesEdited);
1402 }
1403 new_messages
1404 } else {
1405 (None, None)
1406 }
1407 }
1408
1409 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1410 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1411 let api_key = self.api_key.borrow().clone();
1412 if let Some(api_key) = api_key {
1413 let messages = self
1414 .messages(cx)
1415 .take(2)
1416 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1417 .chain(Some(RequestMessage {
1418 role: Role::User,
1419 content:
1420 "Summarize the conversation into a short title without punctuation"
1421 .into(),
1422 }));
1423 let request = OpenAIRequest {
1424 model: self.model.clone(),
1425 messages: messages.collect(),
1426 stream: true,
1427 };
1428
1429 let stream = stream_completion(api_key, cx.background().clone(), request);
1430 self.pending_summary = cx.spawn(|this, mut cx| {
1431 async move {
1432 let mut messages = stream.await?;
1433
1434 while let Some(message) = messages.next().await {
1435 let mut message = message?;
1436 if let Some(choice) = message.choices.pop() {
1437 let text = choice.delta.content.unwrap_or_default();
1438 this.update(&mut cx, |this, cx| {
1439 this.summary
1440 .get_or_insert(Default::default())
1441 .text
1442 .push_str(&text);
1443 cx.emit(ConversationEvent::SummaryChanged);
1444 });
1445 }
1446 }
1447
1448 this.update(&mut cx, |this, cx| {
1449 if let Some(summary) = this.summary.as_mut() {
1450 summary.done = true;
1451 cx.emit(ConversationEvent::SummaryChanged);
1452 }
1453 });
1454
1455 anyhow::Ok(())
1456 }
1457 .log_err()
1458 });
1459 }
1460 }
1461 }
1462
1463 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1464 self.messages_for_offsets([offset], cx).pop()
1465 }
1466
1467 fn messages_for_offsets(
1468 &self,
1469 offsets: impl IntoIterator<Item = usize>,
1470 cx: &AppContext,
1471 ) -> Vec<Message> {
1472 let mut result = Vec::new();
1473
1474 let mut messages = self.messages(cx).peekable();
1475 let mut offsets = offsets.into_iter().peekable();
1476 let mut current_message = messages.next();
1477 while let Some(offset) = offsets.next() {
1478 // Locate the message that contains the offset.
1479 while current_message.as_ref().map_or(false, |message| {
1480 !message.offset_range.contains(&offset) && messages.peek().is_some()
1481 }) {
1482 current_message = messages.next();
1483 }
1484 let Some(message) = current_message.as_ref() else { break };
1485
1486 // Skip offsets that are in the same message.
1487 while offsets.peek().map_or(false, |offset| {
1488 message.offset_range.contains(offset) || messages.peek().is_none()
1489 }) {
1490 offsets.next();
1491 }
1492
1493 result.push(message.clone());
1494 }
1495 result
1496 }
1497
1498 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1499 let buffer = self.buffer.read(cx);
1500 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1501 iter::from_fn(move || {
1502 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1503 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1504 let message_start = message_anchor.start.to_offset(buffer);
1505 let mut message_end = None;
1506 let mut end_ix = start_ix;
1507 while let Some((_, next_message)) = message_anchors.peek() {
1508 if next_message.start.is_valid(buffer) {
1509 message_end = Some(next_message.start);
1510 break;
1511 } else {
1512 end_ix += 1;
1513 message_anchors.next();
1514 }
1515 }
1516 let message_end = message_end
1517 .unwrap_or(language::Anchor::MAX)
1518 .to_offset(buffer);
1519 return Some(Message {
1520 index_range: start_ix..end_ix,
1521 offset_range: message_start..message_end,
1522 id: message_anchor.id,
1523 anchor: message_anchor.start,
1524 role: metadata.role,
1525 sent_at: metadata.sent_at,
1526 status: metadata.status.clone(),
1527 });
1528 }
1529 None
1530 })
1531 }
1532
1533 fn save(
1534 &mut self,
1535 debounce: Option<Duration>,
1536 fs: Arc<dyn Fs>,
1537 cx: &mut ModelContext<Conversation>,
1538 ) {
1539 self.pending_save = cx.spawn(|this, mut cx| async move {
1540 if let Some(debounce) = debounce {
1541 cx.background().timer(debounce).await;
1542 }
1543
1544 let (old_path, summary) = this.read_with(&cx, |this, _| {
1545 let path = this.path.clone();
1546 let summary = if let Some(summary) = this.summary.as_ref() {
1547 if summary.done {
1548 Some(summary.text.clone())
1549 } else {
1550 None
1551 }
1552 } else {
1553 None
1554 };
1555 (path, summary)
1556 });
1557
1558 if let Some(summary) = summary {
1559 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1560 let path = if let Some(old_path) = old_path {
1561 old_path
1562 } else {
1563 let mut discriminant = 1;
1564 let mut new_path;
1565 loop {
1566 new_path = CONVERSATIONS_DIR.join(&format!(
1567 "{} - {}.zed.json",
1568 summary.trim(),
1569 discriminant
1570 ));
1571 if fs.is_file(&new_path).await {
1572 discriminant += 1;
1573 } else {
1574 break;
1575 }
1576 }
1577 new_path
1578 };
1579
1580 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1581 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1582 .await?;
1583 this.update(&mut cx, |this, _| this.path = Some(path));
1584 }
1585
1586 Ok(())
1587 });
1588 }
1589}
1590
1591struct PendingCompletion {
1592 id: usize,
1593 _tasks: Vec<Task<()>>,
1594}
1595
/// Events emitted by `ConversationEditor`.
enum ConversationEditorEvent {
    /// Emitted when the conversation's summary changes.
    TabContentChanged,
}
1599
1600#[derive(Copy, Clone, Debug, PartialEq)]
1601struct ScrollPosition {
1602 offset_before_cursor: Vector2F,
1603 cursor: Anchor,
1604}
1605
1606struct ConversationEditor {
1607 conversation: ModelHandle<Conversation>,
1608 fs: Arc<dyn Fs>,
1609 editor: ViewHandle<Editor>,
1610 blocks: HashSet<BlockId>,
1611 scroll_position: Option<ScrollPosition>,
1612 _subscriptions: Vec<Subscription>,
1613}
1614
1615impl ConversationEditor {
1616 fn new(
1617 api_key: Rc<RefCell<Option<String>>>,
1618 language_registry: Arc<LanguageRegistry>,
1619 fs: Arc<dyn Fs>,
1620 cx: &mut ViewContext<Self>,
1621 ) -> Self {
1622 let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
1623 Self::for_conversation(conversation, fs, cx)
1624 }
1625
1626 fn for_conversation(
1627 conversation: ModelHandle<Conversation>,
1628 fs: Arc<dyn Fs>,
1629 cx: &mut ViewContext<Self>,
1630 ) -> Self {
1631 let editor = cx.add_view(|cx| {
1632 let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
1633 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1634 editor.set_show_gutter(false, cx);
1635 editor
1636 });
1637
1638 let _subscriptions = vec![
1639 cx.observe(&conversation, |_, _, cx| cx.notify()),
1640 cx.subscribe(&conversation, Self::handle_conversation_event),
1641 cx.subscribe(&editor, Self::handle_editor_event),
1642 ];
1643
1644 let mut this = Self {
1645 conversation,
1646 editor,
1647 blocks: Default::default(),
1648 scroll_position: None,
1649 fs,
1650 _subscriptions,
1651 };
1652 this.update_message_headers(cx);
1653 this
1654 }
1655
1656 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1657 let cursors = self.cursors(cx);
1658
1659 let user_messages = self.conversation.update(cx, |conversation, cx| {
1660 let selected_messages = conversation
1661 .messages_for_offsets(cursors, cx)
1662 .into_iter()
1663 .map(|message| message.id)
1664 .collect();
1665 conversation.assist(selected_messages, cx)
1666 });
1667 let new_selections = user_messages
1668 .iter()
1669 .map(|message| {
1670 let cursor = message
1671 .start
1672 .to_offset(self.conversation.read(cx).buffer.read(cx));
1673 cursor..cursor
1674 })
1675 .collect::<Vec<_>>();
1676 if !new_selections.is_empty() {
1677 self.editor.update(cx, |editor, cx| {
1678 editor.change_selections(
1679 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1680 cx,
1681 |selections| selections.select_ranges(new_selections),
1682 );
1683 });
1684 // Avoid scrolling to the new cursor position so the assistant's output is stable.
1685 cx.defer(|this, _| this.scroll_position = None);
1686 }
1687 }
1688
1689 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1690 if !self
1691 .conversation
1692 .update(cx, |conversation, _| conversation.cancel_last_assist())
1693 {
1694 cx.propagate_action();
1695 }
1696 }
1697
1698 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1699 let cursors = self.cursors(cx);
1700 self.conversation.update(cx, |conversation, cx| {
1701 let messages = conversation
1702 .messages_for_offsets(cursors, cx)
1703 .into_iter()
1704 .map(|message| message.id)
1705 .collect();
1706 conversation.cycle_message_roles(messages, cx)
1707 });
1708 }
1709
1710 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1711 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1712 selections
1713 .into_iter()
1714 .map(|selection| selection.head())
1715 .collect()
1716 }
1717
1718 fn handle_conversation_event(
1719 &mut self,
1720 _: ModelHandle<Conversation>,
1721 event: &ConversationEvent,
1722 cx: &mut ViewContext<Self>,
1723 ) {
1724 match event {
1725 ConversationEvent::MessagesEdited => {
1726 self.update_message_headers(cx);
1727 self.conversation.update(cx, |conversation, cx| {
1728 conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1729 });
1730 }
1731 ConversationEvent::SummaryChanged => {
1732 cx.emit(ConversationEditorEvent::TabContentChanged);
1733 self.conversation.update(cx, |conversation, cx| {
1734 conversation.save(None, self.fs.clone(), cx);
1735 });
1736 }
1737 ConversationEvent::StreamedCompletion => {
1738 self.editor.update(cx, |editor, cx| {
1739 if let Some(scroll_position) = self.scroll_position {
1740 let snapshot = editor.snapshot(cx);
1741 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1742 let scroll_top =
1743 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1744 editor.set_scroll_position(
1745 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1746 cx,
1747 );
1748 }
1749 });
1750 }
1751 }
1752 }
1753
1754 fn handle_editor_event(
1755 &mut self,
1756 _: ViewHandle<Editor>,
1757 event: &editor::Event,
1758 cx: &mut ViewContext<Self>,
1759 ) {
1760 match event {
1761 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1762 let cursor_scroll_position = self.cursor_scroll_position(cx);
1763 if *autoscroll {
1764 self.scroll_position = cursor_scroll_position;
1765 } else if self.scroll_position != cursor_scroll_position {
1766 self.scroll_position = None;
1767 }
1768 }
1769 editor::Event::SelectionsChanged { .. } => {
1770 self.scroll_position = self.cursor_scroll_position(cx);
1771 }
1772 _ => {}
1773 }
1774 }
1775
1776 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1777 self.editor.update(cx, |editor, cx| {
1778 let snapshot = editor.snapshot(cx);
1779 let cursor = editor.selections.newest_anchor().head();
1780 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1781 let scroll_position = editor
1782 .scroll_manager
1783 .anchor()
1784 .scroll_position(&snapshot.display_snapshot);
1785
1786 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1787 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1788 Some(ScrollPosition {
1789 cursor,
1790 offset_before_cursor: vec2f(
1791 scroll_position.x(),
1792 cursor_row - scroll_position.y(),
1793 ),
1794 })
1795 } else {
1796 None
1797 }
1798 })
1799 }
1800
1801 fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
1802 self.editor.update(cx, |editor, cx| {
1803 let buffer = editor.buffer().read(cx).snapshot(cx);
1804 let excerpt_id = *buffer.as_singleton().unwrap().0;
1805 let old_blocks = std::mem::take(&mut self.blocks);
1806 let new_blocks = self
1807 .conversation
1808 .read(cx)
1809 .messages(cx)
1810 .map(|message| BlockProperties {
1811 position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
1812 height: 2,
1813 style: BlockStyle::Sticky,
1814 render: Arc::new({
1815 let conversation = self.conversation.clone();
1816 // let metadata = message.metadata.clone();
1817 // let message = message.clone();
1818 move |cx| {
1819 enum Sender {}
1820 enum ErrorTooltip {}
1821
1822 let theme = theme::current(cx);
1823 let style = &theme.assistant;
1824 let message_id = message.id;
1825 let sender = MouseEventHandler::<Sender, _>::new(
1826 message_id.0,
1827 cx,
1828 |state, _| match message.role {
1829 Role::User => {
1830 let style = style.user_sender.style_for(state);
1831 Label::new("You", style.text.clone())
1832 .contained()
1833 .with_style(style.container)
1834 }
1835 Role::Assistant => {
1836 let style = style.assistant_sender.style_for(state);
1837 Label::new("Assistant", style.text.clone())
1838 .contained()
1839 .with_style(style.container)
1840 }
1841 Role::System => {
1842 let style = style.system_sender.style_for(state);
1843 Label::new("System", style.text.clone())
1844 .contained()
1845 .with_style(style.container)
1846 }
1847 },
1848 )
1849 .with_cursor_style(CursorStyle::PointingHand)
1850 .on_down(MouseButton::Left, {
1851 let conversation = conversation.clone();
1852 move |_, _, cx| {
1853 conversation.update(cx, |conversation, cx| {
1854 conversation.cycle_message_roles(
1855 HashSet::from_iter(Some(message_id)),
1856 cx,
1857 )
1858 })
1859 }
1860 });
1861
1862 Flex::row()
1863 .with_child(sender.aligned())
1864 .with_child(
1865 Label::new(
1866 message.sent_at.format("%I:%M%P").to_string(),
1867 style.sent_at.text.clone(),
1868 )
1869 .contained()
1870 .with_style(style.sent_at.container)
1871 .aligned(),
1872 )
1873 .with_children(
1874 if let MessageStatus::Error(error) = &message.status {
1875 Some(
1876 Svg::new("icons/circle_x_mark_12.svg")
1877 .with_color(style.error_icon.color)
1878 .constrained()
1879 .with_width(style.error_icon.width)
1880 .contained()
1881 .with_style(style.error_icon.container)
1882 .with_tooltip::<ErrorTooltip>(
1883 message_id.0,
1884 error.to_string(),
1885 None,
1886 theme.tooltip.clone(),
1887 cx,
1888 )
1889 .aligned(),
1890 )
1891 } else {
1892 None
1893 },
1894 )
1895 .aligned()
1896 .left()
1897 .contained()
1898 .with_style(style.message_header)
1899 .into_any()
1900 }
1901 }),
1902 disposition: BlockDisposition::Above,
1903 })
1904 .collect::<Vec<_>>();
1905
1906 editor.remove_blocks(old_blocks, None, cx);
1907 let ids = editor.insert_blocks(new_blocks, None, cx);
1908 self.blocks = HashSet::from_iter(ids);
1909 });
1910 }
1911
1912 fn quote_selection(
1913 workspace: &mut Workspace,
1914 _: &QuoteSelection,
1915 cx: &mut ViewContext<Workspace>,
1916 ) {
1917 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1918 return;
1919 };
1920 let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1921 return;
1922 };
1923
1924 let text = editor.read_with(cx, |editor, cx| {
1925 let range = editor.selections.newest::<usize>(cx).range();
1926 let buffer = editor.buffer().read(cx).snapshot(cx);
1927 let start_language = buffer.language_at(range.start);
1928 let end_language = buffer.language_at(range.end);
1929 let language_name = if start_language == end_language {
1930 start_language.map(|language| language.name())
1931 } else {
1932 None
1933 };
1934 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1935
1936 let selected_text = buffer.text_for_range(range).collect::<String>();
1937 if selected_text.is_empty() {
1938 None
1939 } else {
1940 Some(if language_name == "markdown" {
1941 selected_text
1942 .lines()
1943 .map(|line| format!("> {}", line))
1944 .collect::<Vec<_>>()
1945 .join("\n")
1946 } else {
1947 format!("```{language_name}\n{selected_text}\n```")
1948 })
1949 }
1950 });
1951
1952 // Activate the panel
1953 if !panel.read(cx).has_focus(cx) {
1954 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1955 }
1956
1957 if let Some(text) = text {
1958 panel.update(cx, |panel, cx| {
1959 let conversation = panel
1960 .active_editor()
1961 .cloned()
1962 .unwrap_or_else(|| panel.new_conversation(cx));
1963 conversation.update(cx, |conversation, cx| {
1964 conversation
1965 .editor
1966 .update(cx, |editor, cx| editor.insert(&text, cx))
1967 });
1968 });
1969 }
1970 }
1971
1972 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1973 let editor = self.editor.read(cx);
1974 let conversation = self.conversation.read(cx);
1975 if editor.selections.count() == 1 {
1976 let selection = editor.selections.newest::<usize>(cx);
1977 let mut copied_text = String::new();
1978 let mut spanned_messages = 0;
1979 for message in conversation.messages(cx) {
1980 if message.offset_range.start >= selection.range().end {
1981 break;
1982 } else if message.offset_range.end >= selection.range().start {
1983 let range = cmp::max(message.offset_range.start, selection.range().start)
1984 ..cmp::min(message.offset_range.end, selection.range().end);
1985 if !range.is_empty() {
1986 spanned_messages += 1;
1987 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1988 for chunk in conversation.buffer.read(cx).text_for_range(range) {
1989 copied_text.push_str(&chunk);
1990 }
1991 copied_text.push('\n');
1992 }
1993 }
1994 }
1995
1996 if spanned_messages > 1 {
1997 cx.platform()
1998 .write_to_clipboard(ClipboardItem::new(copied_text));
1999 return;
2000 }
2001 }
2002
2003 cx.propagate_action();
2004 }
2005
2006 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
2007 self.conversation.update(cx, |conversation, cx| {
2008 let selections = self.editor.read(cx).selections.disjoint_anchors();
2009 for selection in selections.into_iter() {
2010 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
2011 let range = selection
2012 .map(|endpoint| endpoint.to_offset(&buffer))
2013 .range();
2014 conversation.split_message(range, cx);
2015 }
2016 });
2017 }
2018
2019 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
2020 self.conversation.update(cx, |conversation, cx| {
2021 conversation.save(None, self.fs.clone(), cx)
2022 });
2023 }
2024
2025 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2026 self.conversation.update(cx, |conversation, cx| {
2027 let new_model = match conversation.model.as_str() {
2028 "gpt-4-0613" => "gpt-3.5-turbo-0613",
2029 _ => "gpt-4-0613",
2030 };
2031 conversation.set_model(new_model.into(), cx);
2032 });
2033 }
2034
2035 fn title(&self, cx: &AppContext) -> String {
2036 self.conversation
2037 .read(cx)
2038 .summary
2039 .as_ref()
2040 .map(|summary| summary.text.clone())
2041 .unwrap_or_else(|| "New Conversation".into())
2042 }
2043
2044 fn render_current_model(
2045 &self,
2046 style: &AssistantStyle,
2047 cx: &mut ViewContext<Self>,
2048 ) -> impl Element<Self> {
2049 enum Model {}
2050
2051 MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
2052 let style = style.model.style_for(state);
2053 Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
2054 .contained()
2055 .with_style(style.container)
2056 })
2057 .with_cursor_style(CursorStyle::PointingHand)
2058 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
2059 }
2060
2061 fn render_remaining_tokens(
2062 &self,
2063 style: &AssistantStyle,
2064 cx: &mut ViewContext<Self>,
2065 ) -> Option<impl Element<Self>> {
2066 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2067 let remaining_tokens_style = if remaining_tokens <= 0 {
2068 &style.no_remaining_tokens
2069 } else {
2070 &style.remaining_tokens
2071 };
2072 Some(
2073 Label::new(
2074 remaining_tokens.to_string(),
2075 remaining_tokens_style.text.clone(),
2076 )
2077 .contained()
2078 .with_style(remaining_tokens_style.container),
2079 )
2080 }
2081}
2082
2083impl Entity for ConversationEditor {
2084 type Event = ConversationEditorEvent;
2085}
2086
2087impl View for ConversationEditor {
2088 fn ui_name() -> &'static str {
2089 "ConversationEditor"
2090 }
2091
2092 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2093 let theme = &theme::current(cx).assistant;
2094 Stack::new()
2095 .with_child(
2096 ChildView::new(&self.editor, cx)
2097 .contained()
2098 .with_style(theme.container),
2099 )
2100 .with_child(
2101 Flex::row()
2102 .with_child(self.render_current_model(theme, cx))
2103 .with_children(self.render_remaining_tokens(theme, cx))
2104 .aligned()
2105 .top()
2106 .right(),
2107 )
2108 .into_any()
2109 }
2110
2111 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2112 if cx.is_self_focused() {
2113 cx.focus(&self.editor);
2114 }
2115 }
2116}
2117
2118#[derive(Clone, Debug)]
2119struct MessageAnchor {
2120 id: MessageId,
2121 start: language::Anchor,
2122}
2123
2124#[derive(Clone, Debug)]
2125pub struct Message {
2126 offset_range: Range<usize>,
2127 index_range: Range<usize>,
2128 id: MessageId,
2129 anchor: language::Anchor,
2130 role: Role,
2131 sent_at: DateTime<Local>,
2132 status: MessageStatus,
2133}
2134
2135impl Message {
2136 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2137 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2138 content.extend(buffer.text_for_range(self.offset_range.clone()));
2139 RequestMessage {
2140 role: self.role,
2141 content: content.trim_end().into(),
2142 }
2143 }
2144}
2145
2146async fn stream_completion(
2147 api_key: String,
2148 executor: Arc<Background>,
2149 mut request: OpenAIRequest,
2150) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2151 request.stream = true;
2152
2153 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2154
2155 let json_data = serde_json::to_string(&request)?;
2156 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2157 .header("Content-Type", "application/json")
2158 .header("Authorization", format!("Bearer {}", api_key))
2159 .body(json_data)?
2160 .send_async()
2161 .await?;
2162
2163 let status = response.status();
2164 if status == StatusCode::OK {
2165 executor
2166 .spawn(async move {
2167 let mut lines = BufReader::new(response.body_mut()).lines();
2168
2169 fn parse_line(
2170 line: Result<String, io::Error>,
2171 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2172 if let Some(data) = line?.strip_prefix("data: ") {
2173 let event = serde_json::from_str(&data)?;
2174 Ok(Some(event))
2175 } else {
2176 Ok(None)
2177 }
2178 }
2179
2180 while let Some(line) = lines.next().await {
2181 if let Some(event) = parse_line(line).transpose() {
2182 let done = event.as_ref().map_or(false, |event| {
2183 event
2184 .choices
2185 .last()
2186 .map_or(false, |choice| choice.finish_reason.is_some())
2187 });
2188 if tx.unbounded_send(event).is_err() {
2189 break;
2190 }
2191
2192 if done {
2193 break;
2194 }
2195 }
2196 }
2197
2198 anyhow::Ok(())
2199 })
2200 .detach();
2201
2202 Ok(rx)
2203 } else {
2204 let mut body = String::new();
2205 response.body_mut().read_to_string(&mut body).await?;
2206
2207 #[derive(Deserialize)]
2208 struct OpenAIResponse {
2209 error: OpenAIError,
2210 }
2211
2212 #[derive(Deserialize)]
2213 struct OpenAIError {
2214 message: String,
2215 }
2216
2217 match serde_json::from_str::<OpenAIResponse>(&body) {
2218 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2219 "Failed to connect to OpenAI API: {}",
2220 response.error.message,
2221 )),
2222
2223 _ => Err(anyhow!(
2224 "Failed to connect to OpenAI API: {} {}",
2225 response.status(),
2226 body,
2227 )),
2228 }
2229 }
2230}
2231
2232#[cfg(test)]
2233mod tests {
2234 use super::*;
2235 use crate::MessageId;
2236 use gpui::AppContext;
2237
// Walks through inserting messages and editing the buffer, checking after every step
// the `(id, role, offset_range)` triples reported by `messages`. Also covers the merge
// caused by deleting across message boundaries, and undoing/redoing that deletion.
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
    let registry = Arc::new(LanguageRegistry::test());
    let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
    let buffer = conversation.read(cx).buffer.clone();

    // A fresh conversation is expected to hold a single empty user message.
    let message_1 = conversation.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&conversation, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Inserting after message_1 is expected to extend it by one offset and to add an
    // empty assistant message behind it.
    let message_2 = conversation.update(cx, |conversation, cx| {
        conversation
            .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..1),
            (message_2.id, Role::Assistant, 1..1)
        ]
    );

    // Insert one character into each of the two messages.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
    });
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..3)
        ]
    );

    // Append a third message after message_2.
    let message_3 = conversation.update(cx, |conversation, cx| {
        conversation
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_3.id, Role::User, 4..4)
        ]
    );

    // Inserting after message_2 again is expected to place the new message between
    // message_2 and message_3.
    let message_4 = conversation.update(cx, |conversation, cx| {
        conversation
            .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..5),
            (message_3.id, Role::User, 5..5),
        ]
    );

    // Insert one character into each of the last two messages.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
    });
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Deleting across message boundaries merges the messages.
    buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Undoing the deletion should also undo the merge.
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..2),
            (message_2.id, Role::Assistant, 2..4),
            (message_4.id, Role::User, 4..6),
            (message_3.id, Role::User, 6..7),
        ]
    );

    // Redoing the deletion should also redo the merge.
    buffer.update(cx, |buffer, cx| buffer.redo(cx));
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_3.id, Role::User, 3..4),
        ]
    );

    // Ensure we can still insert after a merged message.
    let message_5 = conversation.update(cx, |conversation, cx| {
        conversation
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..3),
            (message_5.id, Role::System, 3..4),
            (message_3.id, Role::User, 4..5)
        ]
    );
}
2363
// Splits messages at several offsets and checks both the resulting buffer text and the
// `(id, role, offset_range)` triples reported by `messages` after each split.
#[gpui::test]
fn test_message_splitting(cx: &mut AppContext) {
    let registry = Arc::new(LanguageRegistry::test());
    let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
    let buffer = conversation.read(cx).buffer.clone();

    // A fresh conversation is expected to hold a single empty user message.
    let message_1 = conversation.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&conversation, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Fill the only message with four lines of text.
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
    });

    // An empty range yields only the second element of the returned pair here.
    let (_, message_2) =
        conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
    let message_2 = message_2.unwrap();

    // We recycle newlines in the middle of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..16),
        ]
    );

    // Split at the same offset again; message_1 now ends right after that position.
    let (_, message_3) =
        conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
    let message_3 = message_3.unwrap();

    // We don't recycle newlines at the end of a split message
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..17),
        ]
    );

    // Splitting at offset 9 is expected to leave the text unchanged.
    let (_, message_4) =
        conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
    let message_4 = message_4.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..17),
        ]
    );

    // Splitting at offset 9 a second time is expected to insert a newline.
    let (_, message_5) =
        conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
    let message_5 = message_5.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..18),
        ]
    );

    // A non-empty range yields two new messages: one for the range and one for the
    // text that follows it.
    let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
        conversation.split_message(14..16, cx)
    });
    let message_6 = message_6.unwrap();
    let message_7 = message_7.unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_3.id, Role::User, 4..5),
            (message_2.id, Role::User, 5..9),
            (message_4.id, Role::User, 9..10),
            (message_5.id, Role::User, 10..14),
            (message_6.id, Role::User, 14..17),
            (message_7.id, Role::User, 17..19),
        ]
    );
}
2457
// Checks which message ids `Conversation::messages_for_offsets` reports for various
// sets of buffer offsets, before and after a trailing empty message is added.
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
    let registry = Arc::new(LanguageRegistry::test());
    let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
    let buffer = conversation.read(cx).buffer.clone();

    // A fresh conversation is expected to hold a single empty user message.
    let message_1 = conversation.read(cx).message_anchors[0].clone();
    assert_eq!(
        messages(&conversation, cx),
        vec![(message_1.id, Role::User, 0..0)]
    );

    // Build three messages containing "aaa", "bbb" and "ccc".
    buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
    let message_2 = conversation
        .update(cx, |conversation, cx| {
            conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

    let message_3 = conversation
        .update(cx, |conversation, cx| {
            conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..11)
        ]
    );

    // One offset inside each message yields each message once.
    assert_eq!(
        message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
        [message_1.id, message_2.id, message_3.id]
    );
    // Two offsets inside message_1 yield it only once; offset 11 (the end of the
    // buffer) is attributed to message_3.
    assert_eq!(
        message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
        [message_1.id, message_3.id]
    );

    // Add an empty message at the very end of the buffer.
    let message_4 = conversation
        .update(cx, |conversation, cx| {
            conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
        })
        .unwrap();
    assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
    assert_eq!(
        messages(&conversation, cx),
        vec![
            (message_1.id, Role::User, 0..4),
            (message_2.id, Role::User, 4..8),
            (message_3.id, Role::User, 8..12),
            (message_4.id, Role::User, 12..12)
        ]
    );
    // Offset 12 is now expected to resolve to the empty trailing message.
    assert_eq!(
        message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
        [message_1.id, message_2.id, message_3.id, message_4.id]
    );

    // Returns the ids of the messages that `messages_for_offsets` reports for `offsets`.
    fn message_ids_for_offsets(
        conversation: &ModelHandle<Conversation>,
        offsets: &[usize],
        cx: &AppContext,
    ) -> Vec<MessageId> {
        conversation
            .read(cx)
            .messages_for_offsets(offsets.iter().copied(), cx)
            .into_iter()
            .map(|message| message.id)
            .collect()
    }
}
2537
// Serializes a conversation and deserializes it again, expecting the copy to have the
// same buffer text and the same `(id, role, offset_range)` triples as the original.
#[gpui::test]
fn test_serialization(cx: &mut AppContext) {
    let registry = Arc::new(LanguageRegistry::test());
    let conversation =
        cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
    let buffer = conversation.read(cx).buffer.clone();
    let message_0 = conversation.read(cx).message_anchors[0].id;
    let message_1 = conversation.update(cx, |conversation, cx| {
        conversation
            .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
            .unwrap()
    });
    let message_2 = conversation.update(cx, |conversation, cx| {
        conversation
            .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    buffer.update(cx, |buffer, cx| {
        buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
        // NOTE(review): presumably closes the current undo transaction so that the
        // `undo` below only reverts the later insertion — confirm.
        buffer.finalize_last_transaction();
    });
    // Insert one more message and undo right away; the assertions below expect neither
    // its text nor the message itself to remain.
    let _message_3 = conversation.update(cx, |conversation, cx| {
        conversation
            .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
            .unwrap()
    });
    buffer.update(cx, |buffer, cx| buffer.undo(cx));
    assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
    assert_eq!(
        messages(&conversation, cx),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );

    // Round-trip through `serialize` / `deserialize`.
    let deserialized_conversation = cx.add_model(|cx| {
        Conversation::deserialize(
            conversation.read(cx).serialize(cx),
            Default::default(),
            Default::default(),
            registry.clone(),
            cx,
        )
    });
    let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
    assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
    assert_eq!(
        messages(&deserialized_conversation, cx),
        [
            (message_0, Role::User, 0..2),
            (message_1.id, Role::Assistant, 2..6),
            (message_2.id, Role::System, 6..6),
        ]
    );
}
2595
2596 fn messages(
2597 conversation: &ModelHandle<Conversation>,
2598 cx: &AppContext,
2599 ) -> Vec<(MessageId, Role, Range<usize>)> {
2600 conversation
2601 .read(cx)
2602 .messages(cx)
2603 .map(|message| (message.id, message.role, message.offset_range))
2604 .collect()
2605 }
2606}