1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI REST API; also used as the key under which the API
/// credentials are written to, read from and deleted from the platform store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Declares the assistant's actions through gpui's `actions!` macro, grouped
// under the `assistant` namespace (the same namespace that `init` filters out
// of the command palette on the stable release channel).
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 if *util::channel::RELEASE_CHANNEL == ReleaseChannel::Stable {
66 cx.update_default_global::<collections::CommandPaletteFilter, _, _>(move |filter, _cx| {
67 filter.filtered_namespaces.insert("assistant");
68 });
69 }
70
71 settings::register::<AssistantSettings>(cx);
72 cx.add_action(
73 |this: &mut AssistantPanel,
74 _: &workspace::NewFile,
75 cx: &mut ViewContext<AssistantPanel>| {
76 this.new_conversation(cx);
77 },
78 );
79 cx.add_action(ConversationEditor::assist);
80 cx.capture_action(ConversationEditor::cancel_last_assist);
81 cx.capture_action(ConversationEditor::save);
82 cx.add_action(ConversationEditor::quote_selection);
83 cx.capture_action(ConversationEditor::copy);
84 cx.add_action(ConversationEditor::split);
85 cx.capture_action(ConversationEditor::cycle_message_role);
86 cx.add_action(AssistantPanel::save_api_key);
87 cx.add_action(AssistantPanel::reset_api_key);
88 cx.add_action(AssistantPanel::toggle_zoom);
89 cx.add_action(AssistantPanel::deploy);
90 cx.add_action(AssistantPanel::select_next_match);
91 cx.add_action(AssistantPanel::select_prev_match);
92 cx.add_action(AssistantPanel::handle_editor_cancel);
93 cx.add_action(
94 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
95 workspace.toggle_panel_focus::<AssistantPanel>(cx);
96 },
97 );
98}
99
/// Events emitted by [`AssistantPanel`]; the panel's `Panel` implementation
/// maps each variant onto the corresponding `should_*_on_event` / `is_focus_event` hook.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` when the panel is currently not zoomed.
    ZoomIn,
    /// Emitted by `toggle_zoom` when the panel is currently zoomed.
    ZoomOut,
    /// Recognized by `is_focus_event`.
    Focus,
    /// Recognized by `should_close_on_event`.
    Close,
    /// Emitted when the dock position derived from settings changes.
    DockPositionChanged,
}
108
/// The dockable assistant panel: a set of conversation editors plus a list of
/// saved conversations, a search toolbar and the API-key prompt.
pub struct AssistantPanel {
    // Weak handle to the owning workspace (used when quoting a selection).
    workspace: WeakViewHandle<Workspace>,
    // Size used while docked left or right; `None` falls back to the settings default.
    width: Option<f32>,
    // Size used while docked at the bottom; `None` falls back to the settings default.
    height: Option<f32>,
    // Index into `editors` of the conversation being shown; `None` shows the saved list.
    active_editor_index: Option<usize>,
    // Value `active_editor_index` had before its most recent change.
    prev_active_editor_index: Option<usize>,
    // Every conversation editor opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    // Metadata of the conversations found on disk; refreshed by the directory watcher.
    saved_conversations: Vec<SavedConversationMetadata>,
    // State of the `UniformList` that renders `saved_conversations`.
    saved_conversations_list_state: UniformListState,
    // Whether the panel is currently zoomed.
    zoomed: bool,
    // Updated by `focus_in` / `focus_out`.
    has_focus: bool,
    // Toolbar that hosts the buffer search bar.
    toolbar: ViewHandle<Toolbar>,
    // API key cell shared (via `Rc`) with the conversations created by this panel.
    api_key: Rc<RefCell<Option<String>>>,
    // Present while the user is being prompted for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    // Set once `set_active` has attempted to load the key, so it is attempted only once.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Subscriptions kept alive for the lifetime of the panel.
    subscriptions: Vec<Subscription>,
    // Task that watches the conversations directory and refreshes `saved_conversations`.
    _watch_saved_conversations: Task<Result<()>>,
}
129
130impl AssistantPanel {
131 pub fn load(
132 workspace: WeakViewHandle<Workspace>,
133 cx: AsyncAppContext,
134 ) -> Task<Result<ViewHandle<Self>>> {
135 cx.spawn(|mut cx| async move {
136 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
137 let saved_conversations = SavedConversationMetadata::list(fs.clone())
138 .await
139 .log_err()
140 .unwrap_or_default();
141
142 // TODO: deserialize state.
143 let workspace_handle = workspace.clone();
144 workspace.update(&mut cx, |workspace, cx| {
145 cx.add_view::<Self, _>(|cx| {
146 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
147 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
148 let mut events = fs
149 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
150 .await;
151 while events.next().await.is_some() {
152 let saved_conversations = SavedConversationMetadata::list(fs.clone())
153 .await
154 .log_err()
155 .unwrap_or_default();
156 this.update(&mut cx, |this, _| {
157 this.saved_conversations = saved_conversations
158 })
159 .ok();
160 }
161
162 anyhow::Ok(())
163 });
164
165 let toolbar = cx.add_view(|cx| {
166 let mut toolbar = Toolbar::new(None);
167 toolbar.set_can_navigate(false, cx);
168 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
169 toolbar
170 });
171 let mut this = Self {
172 workspace: workspace_handle,
173 active_editor_index: Default::default(),
174 prev_active_editor_index: Default::default(),
175 editors: Default::default(),
176 saved_conversations,
177 saved_conversations_list_state: Default::default(),
178 zoomed: false,
179 has_focus: false,
180 toolbar,
181 api_key: Rc::new(RefCell::new(None)),
182 api_key_editor: None,
183 has_read_credentials: false,
184 languages: workspace.app_state().languages.clone(),
185 fs: workspace.app_state().fs.clone(),
186 width: None,
187 height: None,
188 subscriptions: Default::default(),
189 _watch_saved_conversations,
190 };
191
192 let mut old_dock_position = this.position(cx);
193 this.subscriptions =
194 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
195 let new_dock_position = this.position(cx);
196 if new_dock_position != old_dock_position {
197 old_dock_position = new_dock_position;
198 cx.emit(AssistantPanelEvent::DockPositionChanged);
199 }
200 })];
201
202 this
203 })
204 })
205 })
206 }
207
208 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
209 let editor = cx.add_view(|cx| {
210 ConversationEditor::new(
211 self.api_key.clone(),
212 self.languages.clone(),
213 self.fs.clone(),
214 cx,
215 )
216 });
217 self.add_conversation(editor.clone(), cx);
218 editor
219 }
220
221 fn add_conversation(
222 &mut self,
223 editor: ViewHandle<ConversationEditor>,
224 cx: &mut ViewContext<Self>,
225 ) {
226 self.subscriptions
227 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
228
229 let conversation = editor.read(cx).conversation.clone();
230 self.subscriptions
231 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
232
233 let index = self.editors.len();
234 self.editors.push(editor);
235 self.set_active_editor_index(Some(index), cx);
236 }
237
238 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
239 self.prev_active_editor_index = self.active_editor_index;
240 self.active_editor_index = index;
241 if let Some(editor) = self.active_editor() {
242 let editor = editor.read(cx).editor.clone();
243 self.toolbar.update(cx, |toolbar, cx| {
244 toolbar.set_active_item(Some(&editor), cx);
245 });
246 if self.has_focus(cx) {
247 cx.focus(&editor);
248 }
249 } else {
250 self.toolbar.update(cx, |toolbar, cx| {
251 toolbar.set_active_item(None, cx);
252 });
253 }
254
255 cx.notify();
256 }
257
258 fn handle_conversation_editor_event(
259 &mut self,
260 _: ViewHandle<ConversationEditor>,
261 event: &ConversationEditorEvent,
262 cx: &mut ViewContext<Self>,
263 ) {
264 match event {
265 ConversationEditorEvent::TabContentChanged => cx.notify(),
266 }
267 }
268
269 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
270 if let Some(api_key) = self
271 .api_key_editor
272 .as_ref()
273 .map(|editor| editor.read(cx).text(cx))
274 {
275 if !api_key.is_empty() {
276 cx.platform()
277 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
278 .log_err();
279 *self.api_key.borrow_mut() = Some(api_key);
280 self.api_key_editor.take();
281 cx.focus_self();
282 cx.notify();
283 }
284 } else {
285 cx.propagate_action();
286 }
287 }
288
289 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
290 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
291 self.api_key.take();
292 self.api_key_editor = Some(build_api_key_editor(cx));
293 cx.focus_self();
294 cx.notify();
295 }
296
297 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
298 if self.zoomed {
299 cx.emit(AssistantPanelEvent::ZoomOut)
300 } else {
301 cx.emit(AssistantPanelEvent::ZoomIn)
302 }
303 }
304
305 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
306 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
307 if search_bar.update(cx, |search_bar, cx| search_bar.show(action.focus, true, cx)) {
308 return;
309 }
310 }
311 cx.propagate_action();
312 }
313
314 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
315 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
316 if !search_bar.read(cx).is_dismissed() {
317 search_bar.update(cx, |search_bar, cx| {
318 search_bar.dismiss(&Default::default(), cx)
319 });
320 return;
321 }
322 }
323 cx.propagate_action();
324 }
325
326 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
327 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
328 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, cx));
329 }
330 }
331
332 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
333 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
334 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, cx));
335 }
336 }
337
338 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
339 self.editors.get(self.active_editor_index?)
340 }
341
342 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
343 enum History {}
344 let theme = theme::current(cx);
345 let tooltip_style = theme::current(cx).tooltip.clone();
346 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
347 let style = theme.assistant.hamburger_button.style_for(state);
348 Svg::for_style(style.icon.clone())
349 .contained()
350 .with_style(style.container)
351 })
352 .with_cursor_style(CursorStyle::PointingHand)
353 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
354 if this.active_editor().is_some() {
355 this.set_active_editor_index(None, cx);
356 } else {
357 this.set_active_editor_index(this.prev_active_editor_index, cx);
358 }
359 })
360 .with_tooltip::<History>(1, "History".into(), None, tooltip_style, cx)
361 }
362
363 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
364 if self.active_editor().is_some() {
365 vec![
366 Self::render_split_button(cx).into_any(),
367 Self::render_quote_button(cx).into_any(),
368 Self::render_assist_button(cx).into_any(),
369 ]
370 } else {
371 Default::default()
372 }
373 }
374
375 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
376 let theme = theme::current(cx);
377 let tooltip_style = theme::current(cx).tooltip.clone();
378 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
379 let style = theme.assistant.split_button.style_for(state);
380 Svg::for_style(style.icon.clone())
381 .contained()
382 .with_style(style.container)
383 })
384 .with_cursor_style(CursorStyle::PointingHand)
385 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
386 if let Some(active_editor) = this.active_editor() {
387 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
388 }
389 })
390 .with_tooltip::<Split>(
391 1,
392 "Split Message".into(),
393 Some(Box::new(Split)),
394 tooltip_style,
395 cx,
396 )
397 }
398
399 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
400 let theme = theme::current(cx);
401 let tooltip_style = theme::current(cx).tooltip.clone();
402 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
403 let style = theme.assistant.assist_button.style_for(state);
404 Svg::for_style(style.icon.clone())
405 .contained()
406 .with_style(style.container)
407 })
408 .with_cursor_style(CursorStyle::PointingHand)
409 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
410 if let Some(active_editor) = this.active_editor() {
411 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
412 }
413 })
414 .with_tooltip::<Assist>(
415 1,
416 "Assist".into(),
417 Some(Box::new(Assist)),
418 tooltip_style,
419 cx,
420 )
421 }
422
423 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
424 let theme = theme::current(cx);
425 let tooltip_style = theme::current(cx).tooltip.clone();
426 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
427 let style = theme.assistant.quote_button.style_for(state);
428 Svg::for_style(style.icon.clone())
429 .contained()
430 .with_style(style.container)
431 })
432 .with_cursor_style(CursorStyle::PointingHand)
433 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
434 if let Some(workspace) = this.workspace.upgrade(cx) {
435 cx.window_context().defer(move |cx| {
436 workspace.update(cx, |workspace, cx| {
437 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
438 });
439 });
440 }
441 })
442 .with_tooltip::<QuoteSelection>(
443 1,
444 "Quote Selection".into(),
445 Some(Box::new(QuoteSelection)),
446 tooltip_style,
447 cx,
448 )
449 }
450
451 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
452 let theme = theme::current(cx);
453 let tooltip_style = theme::current(cx).tooltip.clone();
454 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
455 let style = theme.assistant.plus_button.style_for(state);
456 Svg::for_style(style.icon.clone())
457 .contained()
458 .with_style(style.container)
459 })
460 .with_cursor_style(CursorStyle::PointingHand)
461 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
462 this.new_conversation(cx);
463 })
464 .with_tooltip::<NewConversation>(
465 1,
466 "New Conversation".into(),
467 Some(Box::new(NewConversation)),
468 tooltip_style,
469 cx,
470 )
471 }
472
473 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
474 enum ToggleZoomButton {}
475
476 let theme = theme::current(cx);
477 let tooltip_style = theme::current(cx).tooltip.clone();
478 let style = if self.zoomed {
479 &theme.assistant.zoom_out_button
480 } else {
481 &theme.assistant.zoom_in_button
482 };
483
484 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
485 let style = style.style_for(state);
486 Svg::for_style(style.icon.clone())
487 .contained()
488 .with_style(style.container)
489 })
490 .with_cursor_style(CursorStyle::PointingHand)
491 .on_click(MouseButton::Left, |_, this, cx| {
492 this.toggle_zoom(&ToggleZoom, cx);
493 })
494 .with_tooltip::<ToggleZoom>(
495 0,
496 if self.zoomed {
497 "Zoom Out".into()
498 } else {
499 "Zoom In".into()
500 },
501 Some(Box::new(ToggleZoom)),
502 tooltip_style,
503 cx,
504 )
505 }
506
507 fn render_saved_conversation(
508 &mut self,
509 index: usize,
510 cx: &mut ViewContext<Self>,
511 ) -> impl Element<Self> {
512 let conversation = &self.saved_conversations[index];
513 let path = conversation.path.clone();
514 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
515 let style = &theme::current(cx).assistant.saved_conversation;
516 Flex::row()
517 .with_child(
518 Label::new(
519 conversation.mtime.format("%F %I:%M%p").to_string(),
520 style.saved_at.text.clone(),
521 )
522 .aligned()
523 .contained()
524 .with_style(style.saved_at.container),
525 )
526 .with_child(
527 Label::new(conversation.title.clone(), style.title.text.clone())
528 .aligned()
529 .contained()
530 .with_style(style.title.container),
531 )
532 .contained()
533 .with_style(*style.container.style_for(state))
534 })
535 .with_cursor_style(CursorStyle::PointingHand)
536 .on_click(MouseButton::Left, move |_, this, cx| {
537 this.open_conversation(path.clone(), cx)
538 .detach_and_log_err(cx)
539 })
540 }
541
542 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
543 if let Some(ix) = self.editor_index_for_path(&path, cx) {
544 self.set_active_editor_index(Some(ix), cx);
545 return Task::ready(Ok(()));
546 }
547
548 let fs = self.fs.clone();
549 let api_key = self.api_key.clone();
550 let languages = self.languages.clone();
551 cx.spawn(|this, mut cx| async move {
552 let saved_conversation = fs.load(&path).await?;
553 let saved_conversation = serde_json::from_str(&saved_conversation)?;
554 let conversation = cx.add_model(|cx| {
555 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
556 });
557 this.update(&mut cx, |this, cx| {
558 // If, by the time we've loaded the conversation, the user has already opened
559 // the same conversation, we don't want to open it again.
560 if let Some(ix) = this.editor_index_for_path(&path, cx) {
561 this.set_active_editor_index(Some(ix), cx);
562 } else {
563 let editor = cx
564 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
565 this.add_conversation(editor, cx);
566 }
567 })?;
568 Ok(())
569 })
570 }
571
572 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
573 self.editors
574 .iter()
575 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
576 }
577}
578
579fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
580 cx.add_view(|cx| {
581 let mut editor = Editor::single_line(
582 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
583 cx,
584 );
585 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
586 editor
587 })
588}
589
// The panel emits `AssistantPanelEvent`s (zoom and dock-position changes are
// emitted in this file).
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
593
594impl View for AssistantPanel {
595 fn ui_name() -> &'static str {
596 "AssistantPanel"
597 }
598
599 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
600 let theme = &theme::current(cx);
601 let style = &theme.assistant;
602 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
603 Flex::column()
604 .with_child(
605 Text::new(
606 "Paste your OpenAI API key and press Enter to use the assistant",
607 style.api_key_prompt.text.clone(),
608 )
609 .aligned(),
610 )
611 .with_child(
612 ChildView::new(api_key_editor, cx)
613 .contained()
614 .with_style(style.api_key_editor.container)
615 .aligned(),
616 )
617 .contained()
618 .with_style(style.api_key_prompt.container)
619 .aligned()
620 .into_any()
621 } else {
622 let title = self.active_editor().map(|editor| {
623 Label::new(editor.read(cx).title(cx), style.title.text.clone())
624 .contained()
625 .with_style(style.title.container)
626 .aligned()
627 .left()
628 .flex(1., false)
629 });
630 let mut header = Flex::row()
631 .with_child(Self::render_hamburger_button(cx).aligned())
632 .with_children(title);
633 if self.has_focus {
634 header.add_children(
635 self.render_editor_tools(cx)
636 .into_iter()
637 .map(|tool| tool.aligned().flex_float()),
638 );
639 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
640 header.add_child(self.render_zoom_button(cx).aligned());
641 }
642
643 Flex::column()
644 .with_child(
645 header
646 .contained()
647 .with_style(theme.workspace.tab_bar.container)
648 .expanded()
649 .constrained()
650 .with_height(theme.workspace.tab_bar.height),
651 )
652 .with_children(if self.toolbar.read(cx).hidden() {
653 None
654 } else {
655 Some(ChildView::new(&self.toolbar, cx).expanded())
656 })
657 .with_child(if let Some(editor) = self.active_editor() {
658 ChildView::new(editor, cx).flex(1., true).into_any()
659 } else {
660 UniformList::new(
661 self.saved_conversations_list_state.clone(),
662 self.saved_conversations.len(),
663 cx,
664 |this, range, items, cx| {
665 for ix in range {
666 items.push(this.render_saved_conversation(ix, cx).into_any());
667 }
668 },
669 )
670 .flex(1., true)
671 .into_any()
672 })
673 .into_any()
674 }
675 }
676
677 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
678 self.has_focus = true;
679 self.toolbar
680 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
681 cx.notify();
682 if cx.is_self_focused() {
683 if let Some(editor) = self.active_editor() {
684 cx.focus(editor);
685 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
686 cx.focus(api_key_editor);
687 }
688 }
689 }
690
691 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
692 self.has_focus = false;
693 self.toolbar
694 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
695 cx.notify();
696 }
697}
698
699impl Panel for AssistantPanel {
700 fn position(&self, cx: &WindowContext) -> DockPosition {
701 match settings::get::<AssistantSettings>(cx).dock {
702 AssistantDockPosition::Left => DockPosition::Left,
703 AssistantDockPosition::Bottom => DockPosition::Bottom,
704 AssistantDockPosition::Right => DockPosition::Right,
705 }
706 }
707
708 fn position_is_valid(&self, _: DockPosition) -> bool {
709 true
710 }
711
712 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
713 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
714 let dock = match position {
715 DockPosition::Left => AssistantDockPosition::Left,
716 DockPosition::Bottom => AssistantDockPosition::Bottom,
717 DockPosition::Right => AssistantDockPosition::Right,
718 };
719 settings.dock = Some(dock);
720 });
721 }
722
723 fn size(&self, cx: &WindowContext) -> f32 {
724 let settings = settings::get::<AssistantSettings>(cx);
725 match self.position(cx) {
726 DockPosition::Left | DockPosition::Right => {
727 self.width.unwrap_or_else(|| settings.default_width)
728 }
729 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
730 }
731 }
732
733 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
734 match self.position(cx) {
735 DockPosition::Left | DockPosition::Right => self.width = Some(size),
736 DockPosition::Bottom => self.height = Some(size),
737 }
738 cx.notify();
739 }
740
741 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
742 matches!(event, AssistantPanelEvent::ZoomIn)
743 }
744
745 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
746 matches!(event, AssistantPanelEvent::ZoomOut)
747 }
748
749 fn is_zoomed(&self, _: &WindowContext) -> bool {
750 self.zoomed
751 }
752
753 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
754 self.zoomed = zoomed;
755 cx.notify();
756 }
757
758 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
759 if active {
760 if self.api_key.borrow().is_none() && !self.has_read_credentials {
761 self.has_read_credentials = true;
762 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
763 Some(api_key)
764 } else if let Some((_, api_key)) = cx
765 .platform()
766 .read_credentials(OPENAI_API_URL)
767 .log_err()
768 .flatten()
769 {
770 String::from_utf8(api_key).log_err()
771 } else {
772 None
773 };
774 if let Some(api_key) = api_key {
775 *self.api_key.borrow_mut() = Some(api_key);
776 } else if self.api_key_editor.is_none() {
777 self.api_key_editor = Some(build_api_key_editor(cx));
778 cx.notify();
779 }
780 }
781
782 if self.editors.is_empty() {
783 self.new_conversation(cx);
784 }
785 }
786 }
787
788 fn icon_path(&self) -> &'static str {
789 "icons/robot_14.svg"
790 }
791
792 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
793 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
794 }
795
796 fn should_change_position_on_event(event: &Self::Event) -> bool {
797 matches!(event, AssistantPanelEvent::DockPositionChanged)
798 }
799
800 fn should_activate_on_event(_: &Self::Event) -> bool {
801 false
802 }
803
804 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
805 matches!(event, AssistantPanelEvent::Close)
806 }
807
808 fn has_focus(&self, _: &WindowContext) -> bool {
809 self.has_focus
810 }
811
812 fn is_focus_event(event: &Self::Event) -> bool {
813 matches!(event, AssistantPanelEvent::Focus)
814 }
815}
816
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// Emitted when the underlying buffer reports an `Edited` event.
    MessagesEdited,
    // NOTE(review): presumably emitted when the summary text changes — the
    // emitting code is not visible in this part of the file.
    SummaryChanged,
    // NOTE(review): presumably emitted as completion chunks arrive — the
    // emitting code is not visible in this part of the file.
    StreamedCompletion,
}
822
/// A conversation summary; its `text` is what gets persisted as the saved
/// conversation's `summary`.
#[derive(Default)]
struct Summary {
    // The summary text.
    text: String,
    // Set to `true` for summaries restored from a saved conversation.
    // NOTE(review): presumably `false` while a summary is still being produced — confirm.
    done: bool,
}
828
/// Model backing a single assistant conversation: a text buffer divided into
/// messages by anchors, plus per-message metadata and request bookkeeping.
struct Conversation {
    // Buffer holding the text of all messages.
    buffer: ModelHandle<Buffer>,
    // Start anchor and id of each message in `buffer`.
    message_anchors: Vec<MessageAnchor>,
    // Role, timestamp and status for each message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    // Id that will be assigned to the next message that is created.
    next_message_id: MessageId,
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    // Name of the model used for requests and for token counting.
    model: String,
    // Most recently computed token count of the conversation; `None` until the first count finishes.
    token_count: Option<usize>,
    // Context size of `model` as reported by `tiktoken_rs`.
    max_token_count: usize,
    // Debounced background task that recomputes `token_count`.
    pending_token_count: Task<Option<()>>,
    // API key cell shared with the panel.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    // File this conversation was loaded from, if it was deserialized from disk.
    path: Option<PathBuf>,
    // Keeps the buffer event subscription alive.
    _subscriptions: Vec<Subscription>,
}
847
// Conversations emit `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
851
852impl Conversation {
853 fn new(
854 api_key: Rc<RefCell<Option<String>>>,
855 language_registry: Arc<LanguageRegistry>,
856 cx: &mut ModelContext<Self>,
857 ) -> Self {
858 let model = "gpt-3.5-turbo-0613";
859 let markdown = language_registry.language_for_name("Markdown");
860 let buffer = cx.add_model(|cx| {
861 let mut buffer = Buffer::new(0, "", cx);
862 buffer.set_language_registry(language_registry);
863 cx.spawn_weak(|buffer, mut cx| async move {
864 let markdown = markdown.await?;
865 let buffer = buffer
866 .upgrade(&cx)
867 .ok_or_else(|| anyhow!("buffer was dropped"))?;
868 buffer.update(&mut cx, |buffer, cx| {
869 buffer.set_language(Some(markdown), cx)
870 });
871 anyhow::Ok(())
872 })
873 .detach_and_log_err(cx);
874 buffer
875 });
876
877 let mut this = Self {
878 message_anchors: Default::default(),
879 messages_metadata: Default::default(),
880 next_message_id: Default::default(),
881 summary: None,
882 pending_summary: Task::ready(None),
883 completion_count: Default::default(),
884 pending_completions: Default::default(),
885 token_count: None,
886 max_token_count: tiktoken_rs::model::get_context_size(model),
887 pending_token_count: Task::ready(None),
888 model: model.into(),
889 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
890 pending_save: Task::ready(Ok(())),
891 path: None,
892 api_key,
893 buffer,
894 };
895 let message = MessageAnchor {
896 id: MessageId(post_inc(&mut this.next_message_id.0)),
897 start: language::Anchor::MIN,
898 };
899 this.message_anchors.push(message.clone());
900 this.messages_metadata.insert(
901 message.id,
902 MessageMetadata {
903 role: Role::User,
904 sent_at: Local::now(),
905 status: MessageStatus::Done,
906 },
907 );
908
909 this.count_remaining_tokens(cx);
910 this
911 }
912
913 fn serialize(&self, cx: &AppContext) -> SavedConversation {
914 SavedConversation {
915 zed: "conversation".into(),
916 version: SavedConversation::VERSION.into(),
917 text: self.buffer.read(cx).text(),
918 message_metadata: self.messages_metadata.clone(),
919 messages: self
920 .messages(cx)
921 .map(|message| SavedMessage {
922 id: message.id,
923 start: message.offset_range.start,
924 })
925 .collect(),
926 summary: self
927 .summary
928 .as_ref()
929 .map(|summary| summary.text.clone())
930 .unwrap_or_default(),
931 model: self.model.clone(),
932 }
933 }
934
935 fn deserialize(
936 saved_conversation: SavedConversation,
937 path: PathBuf,
938 api_key: Rc<RefCell<Option<String>>>,
939 language_registry: Arc<LanguageRegistry>,
940 cx: &mut ModelContext<Self>,
941 ) -> Self {
942 let model = saved_conversation.model;
943 let markdown = language_registry.language_for_name("Markdown");
944 let mut message_anchors = Vec::new();
945 let mut next_message_id = MessageId(0);
946 let buffer = cx.add_model(|cx| {
947 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
948 for message in saved_conversation.messages {
949 message_anchors.push(MessageAnchor {
950 id: message.id,
951 start: buffer.anchor_before(message.start),
952 });
953 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
954 }
955 buffer.set_language_registry(language_registry);
956 cx.spawn_weak(|buffer, mut cx| async move {
957 let markdown = markdown.await?;
958 let buffer = buffer
959 .upgrade(&cx)
960 .ok_or_else(|| anyhow!("buffer was dropped"))?;
961 buffer.update(&mut cx, |buffer, cx| {
962 buffer.set_language(Some(markdown), cx)
963 });
964 anyhow::Ok(())
965 })
966 .detach_and_log_err(cx);
967 buffer
968 });
969
970 let mut this = Self {
971 message_anchors,
972 messages_metadata: saved_conversation.message_metadata,
973 next_message_id,
974 summary: Some(Summary {
975 text: saved_conversation.summary,
976 done: true,
977 }),
978 pending_summary: Task::ready(None),
979 completion_count: Default::default(),
980 pending_completions: Default::default(),
981 token_count: None,
982 max_token_count: tiktoken_rs::model::get_context_size(&model),
983 pending_token_count: Task::ready(None),
984 model,
985 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
986 pending_save: Task::ready(Ok(())),
987 path: Some(path),
988 api_key,
989 buffer,
990 };
991 this.count_remaining_tokens(cx);
992 this
993 }
994
995 fn handle_buffer_event(
996 &mut self,
997 _: ModelHandle<Buffer>,
998 event: &language::Event,
999 cx: &mut ModelContext<Self>,
1000 ) {
1001 match event {
1002 language::Event::Edited => {
1003 self.count_remaining_tokens(cx);
1004 cx.emit(ConversationEvent::MessagesEdited);
1005 }
1006 _ => {}
1007 }
1008 }
1009
1010 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1011 let messages = self
1012 .messages(cx)
1013 .into_iter()
1014 .filter_map(|message| {
1015 Some(tiktoken_rs::ChatCompletionRequestMessage {
1016 role: match message.role {
1017 Role::User => "user".into(),
1018 Role::Assistant => "assistant".into(),
1019 Role::System => "system".into(),
1020 },
1021 content: self
1022 .buffer
1023 .read(cx)
1024 .text_for_range(message.offset_range)
1025 .collect(),
1026 name: None,
1027 })
1028 })
1029 .collect::<Vec<_>>();
1030 let model = self.model.clone();
1031 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1032 async move {
1033 cx.background().timer(Duration::from_millis(200)).await;
1034 let token_count = cx
1035 .background()
1036 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1037 .await?;
1038
1039 this.upgrade(&cx)
1040 .ok_or_else(|| anyhow!("conversation was dropped"))?
1041 .update(&mut cx, |this, cx| {
1042 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1043 this.token_count = Some(token_count);
1044 cx.notify()
1045 });
1046 anyhow::Ok(())
1047 }
1048 .log_err()
1049 });
1050 }
1051
1052 fn remaining_tokens(&self) -> Option<isize> {
1053 Some(self.max_token_count as isize - self.token_count? as isize)
1054 }
1055
    /// Switches the conversation to `model`, schedules a token recount for the
    /// new model and notifies observers.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
1061
1062 fn assist(
1063 &mut self,
1064 selected_messages: HashSet<MessageId>,
1065 cx: &mut ModelContext<Self>,
1066 ) -> Vec<MessageAnchor> {
1067 let mut user_messages = Vec::new();
1068 let mut tasks = Vec::new();
1069
1070 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1071 message
1072 .start
1073 .is_valid(self.buffer.read(cx))
1074 .then_some(message.id)
1075 });
1076
1077 for selected_message_id in selected_messages {
1078 let selected_message_role =
1079 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1080 metadata.role
1081 } else {
1082 continue;
1083 };
1084
1085 if selected_message_role == Role::Assistant {
1086 if let Some(user_message) = self.insert_message_after(
1087 selected_message_id,
1088 Role::User,
1089 MessageStatus::Done,
1090 cx,
1091 ) {
1092 user_messages.push(user_message);
1093 } else {
1094 continue;
1095 }
1096 } else {
1097 let request = OpenAIRequest {
1098 model: self.model.clone(),
1099 messages: self
1100 .messages(cx)
1101 .filter(|message| matches!(message.status, MessageStatus::Done))
1102 .flat_map(|message| {
1103 let mut system_message = None;
1104 if message.id == selected_message_id {
1105 system_message = Some(RequestMessage {
1106 role: Role::System,
1107 content: concat!(
1108 "Treat the following messages as additional knowledge you have learned about, ",
1109 "but act as if they were not part of this conversation. That is, treat them ",
1110 "as if the user didn't see them and couldn't possibly inquire about them."
1111 ).into()
1112 });
1113 }
1114
1115 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1116 })
1117 .chain(Some(RequestMessage {
1118 role: Role::System,
1119 content: format!(
1120 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1121 selected_message_id.0
1122 ),
1123 }))
1124 .collect(),
1125 stream: true,
1126 };
1127
1128 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1129 let stream = stream_completion(api_key, cx.background().clone(), request);
1130 let assistant_message = self
1131 .insert_message_after(
1132 selected_message_id,
1133 Role::Assistant,
1134 MessageStatus::Pending,
1135 cx,
1136 )
1137 .unwrap();
1138
1139 // Queue up the user's next reply
1140 if Some(selected_message_id) == last_message_id {
1141 let user_message = self
1142 .insert_message_after(
1143 assistant_message.id,
1144 Role::User,
1145 MessageStatus::Done,
1146 cx,
1147 )
1148 .unwrap();
1149 user_messages.push(user_message);
1150 }
1151
1152 tasks.push(cx.spawn_weak({
1153 |this, mut cx| async move {
1154 let assistant_message_id = assistant_message.id;
1155 let stream_completion = async {
1156 let mut messages = stream.await?;
1157
1158 while let Some(message) = messages.next().await {
1159 let mut message = message?;
1160 if let Some(choice) = message.choices.pop() {
1161 this.upgrade(&cx)
1162 .ok_or_else(|| anyhow!("conversation was dropped"))?
1163 .update(&mut cx, |this, cx| {
1164 let text: Arc<str> = choice.delta.content?.into();
1165 let message_ix = this.message_anchors.iter().position(
1166 |message| message.id == assistant_message_id,
1167 )?;
1168 this.buffer.update(cx, |buffer, cx| {
1169 let offset = this.message_anchors[message_ix + 1..]
1170 .iter()
1171 .find(|message| message.start.is_valid(buffer))
1172 .map_or(buffer.len(), |message| {
1173 message
1174 .start
1175 .to_offset(buffer)
1176 .saturating_sub(1)
1177 });
1178 buffer.edit([(offset..offset, text)], None, cx);
1179 });
1180 cx.emit(ConversationEvent::StreamedCompletion);
1181
1182 Some(())
1183 });
1184 }
1185 smol::future::yield_now().await;
1186 }
1187
1188 this.upgrade(&cx)
1189 .ok_or_else(|| anyhow!("conversation was dropped"))?
1190 .update(&mut cx, |this, cx| {
1191 this.pending_completions.retain(|completion| {
1192 completion.id != this.completion_count
1193 });
1194 this.summarize(cx);
1195 });
1196
1197 anyhow::Ok(())
1198 };
1199
1200 let result = stream_completion.await;
1201 if let Some(this) = this.upgrade(&cx) {
1202 this.update(&mut cx, |this, cx| {
1203 if let Some(metadata) =
1204 this.messages_metadata.get_mut(&assistant_message.id)
1205 {
1206 match result {
1207 Ok(_) => {
1208 metadata.status = MessageStatus::Done;
1209 }
1210 Err(error) => {
1211 metadata.status = MessageStatus::Error(
1212 error.to_string().trim().into(),
1213 );
1214 }
1215 }
1216 cx.notify();
1217 }
1218 });
1219 }
1220 }
1221 }));
1222 }
1223 }
1224
1225 if !tasks.is_empty() {
1226 self.pending_completions.push(PendingCompletion {
1227 id: post_inc(&mut self.completion_count),
1228 _tasks: tasks,
1229 });
1230 }
1231
1232 user_messages
1233 }
1234
1235 fn cancel_last_assist(&mut self) -> bool {
1236 self.pending_completions.pop().is_some()
1237 }
1238
1239 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1240 for id in ids {
1241 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1242 metadata.role.cycle();
1243 cx.emit(ConversationEvent::MessagesEdited);
1244 cx.notify();
1245 }
1246 }
1247 }
1248
1249 fn insert_message_after(
1250 &mut self,
1251 message_id: MessageId,
1252 role: Role,
1253 status: MessageStatus,
1254 cx: &mut ModelContext<Self>,
1255 ) -> Option<MessageAnchor> {
1256 if let Some(prev_message_ix) = self
1257 .message_anchors
1258 .iter()
1259 .position(|message| message.id == message_id)
1260 {
1261 // Find the next valid message after the one we were given.
1262 let mut next_message_ix = prev_message_ix + 1;
1263 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1264 if next_message.start.is_valid(self.buffer.read(cx)) {
1265 break;
1266 }
1267 next_message_ix += 1;
1268 }
1269
1270 let start = self.buffer.update(cx, |buffer, cx| {
1271 let offset = self
1272 .message_anchors
1273 .get(next_message_ix)
1274 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1275 buffer.edit([(offset..offset, "\n")], None, cx);
1276 buffer.anchor_before(offset + 1)
1277 });
1278 let message = MessageAnchor {
1279 id: MessageId(post_inc(&mut self.next_message_id.0)),
1280 start,
1281 };
1282 self.message_anchors
1283 .insert(next_message_ix, message.clone());
1284 self.messages_metadata.insert(
1285 message.id,
1286 MessageMetadata {
1287 role,
1288 sent_at: Local::now(),
1289 status,
1290 },
1291 );
1292 cx.emit(ConversationEvent::MessagesEdited);
1293 Some(message)
1294 } else {
1295 None
1296 }
1297 }
1298
1299 fn split_message(
1300 &mut self,
1301 range: Range<usize>,
1302 cx: &mut ModelContext<Self>,
1303 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1304 let start_message = self.message_for_offset(range.start, cx);
1305 let end_message = self.message_for_offset(range.end, cx);
1306 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1307 // Prevent splitting when range spans multiple messages.
1308 if start_message.id != end_message.id {
1309 return (None, None);
1310 }
1311
1312 let message = start_message;
1313 let role = message.role;
1314 let mut edited_buffer = false;
1315
1316 let mut suffix_start = None;
1317 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1318 {
1319 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1320 suffix_start = Some(range.end + 1);
1321 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1322 suffix_start = Some(range.end);
1323 }
1324 }
1325
1326 let suffix = if let Some(suffix_start) = suffix_start {
1327 MessageAnchor {
1328 id: MessageId(post_inc(&mut self.next_message_id.0)),
1329 start: self.buffer.read(cx).anchor_before(suffix_start),
1330 }
1331 } else {
1332 self.buffer.update(cx, |buffer, cx| {
1333 buffer.edit([(range.end..range.end, "\n")], None, cx);
1334 });
1335 edited_buffer = true;
1336 MessageAnchor {
1337 id: MessageId(post_inc(&mut self.next_message_id.0)),
1338 start: self.buffer.read(cx).anchor_before(range.end + 1),
1339 }
1340 };
1341
1342 self.message_anchors
1343 .insert(message.index_range.end + 1, suffix.clone());
1344 self.messages_metadata.insert(
1345 suffix.id,
1346 MessageMetadata {
1347 role,
1348 sent_at: Local::now(),
1349 status: MessageStatus::Done,
1350 },
1351 );
1352
1353 let new_messages =
1354 if range.start == range.end || range.start == message.offset_range.start {
1355 (None, Some(suffix))
1356 } else {
1357 let mut prefix_end = None;
1358 if range.start > message.offset_range.start
1359 && range.end < message.offset_range.end - 1
1360 {
1361 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1362 prefix_end = Some(range.start + 1);
1363 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1364 == Some('\n')
1365 {
1366 prefix_end = Some(range.start);
1367 }
1368 }
1369
1370 let selection = if let Some(prefix_end) = prefix_end {
1371 cx.emit(ConversationEvent::MessagesEdited);
1372 MessageAnchor {
1373 id: MessageId(post_inc(&mut self.next_message_id.0)),
1374 start: self.buffer.read(cx).anchor_before(prefix_end),
1375 }
1376 } else {
1377 self.buffer.update(cx, |buffer, cx| {
1378 buffer.edit([(range.start..range.start, "\n")], None, cx)
1379 });
1380 edited_buffer = true;
1381 MessageAnchor {
1382 id: MessageId(post_inc(&mut self.next_message_id.0)),
1383 start: self.buffer.read(cx).anchor_before(range.end + 1),
1384 }
1385 };
1386
1387 self.message_anchors
1388 .insert(message.index_range.end + 1, selection.clone());
1389 self.messages_metadata.insert(
1390 selection.id,
1391 MessageMetadata {
1392 role,
1393 sent_at: Local::now(),
1394 status: MessageStatus::Done,
1395 },
1396 );
1397 (Some(selection), Some(suffix))
1398 };
1399
1400 if !edited_buffer {
1401 cx.emit(ConversationEvent::MessagesEdited);
1402 }
1403 new_messages
1404 } else {
1405 (None, None)
1406 }
1407 }
1408
1409 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1410 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1411 let api_key = self.api_key.borrow().clone();
1412 if let Some(api_key) = api_key {
1413 let messages = self
1414 .messages(cx)
1415 .take(2)
1416 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1417 .chain(Some(RequestMessage {
1418 role: Role::User,
1419 content:
1420 "Summarize the conversation into a short title without punctuation"
1421 .into(),
1422 }));
1423 let request = OpenAIRequest {
1424 model: self.model.clone(),
1425 messages: messages.collect(),
1426 stream: true,
1427 };
1428
1429 let stream = stream_completion(api_key, cx.background().clone(), request);
1430 self.pending_summary = cx.spawn(|this, mut cx| {
1431 async move {
1432 let mut messages = stream.await?;
1433
1434 while let Some(message) = messages.next().await {
1435 let mut message = message?;
1436 if let Some(choice) = message.choices.pop() {
1437 let text = choice.delta.content.unwrap_or_default();
1438 this.update(&mut cx, |this, cx| {
1439 this.summary
1440 .get_or_insert(Default::default())
1441 .text
1442 .push_str(&text);
1443 cx.emit(ConversationEvent::SummaryChanged);
1444 });
1445 }
1446 }
1447
1448 this.update(&mut cx, |this, cx| {
1449 if let Some(summary) = this.summary.as_mut() {
1450 summary.done = true;
1451 cx.emit(ConversationEvent::SummaryChanged);
1452 }
1453 });
1454
1455 anyhow::Ok(())
1456 }
1457 .log_err()
1458 });
1459 }
1460 }
1461 }
1462
1463 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1464 self.messages_for_offsets([offset], cx).pop()
1465 }
1466
1467 fn messages_for_offsets(
1468 &self,
1469 offsets: impl IntoIterator<Item = usize>,
1470 cx: &AppContext,
1471 ) -> Vec<Message> {
1472 let mut result = Vec::new();
1473
1474 let mut messages = self.messages(cx).peekable();
1475 let mut offsets = offsets.into_iter().peekable();
1476 let mut current_message = messages.next();
1477 while let Some(offset) = offsets.next() {
1478 // Locate the message that contains the offset.
1479 while current_message.as_ref().map_or(false, |message| {
1480 !message.offset_range.contains(&offset) && messages.peek().is_some()
1481 }) {
1482 current_message = messages.next();
1483 }
1484 let Some(message) = current_message.as_ref() else { break };
1485
1486 // Skip offsets that are in the same message.
1487 while offsets.peek().map_or(false, |offset| {
1488 message.offset_range.contains(offset) || messages.peek().is_none()
1489 }) {
1490 offsets.next();
1491 }
1492
1493 result.push(message.clone());
1494 }
1495 result
1496 }
1497
1498 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1499 let buffer = self.buffer.read(cx);
1500 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1501 iter::from_fn(move || {
1502 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1503 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1504 let message_start = message_anchor.start.to_offset(buffer);
1505 let mut message_end = None;
1506 let mut end_ix = start_ix;
1507 while let Some((_, next_message)) = message_anchors.peek() {
1508 if next_message.start.is_valid(buffer) {
1509 message_end = Some(next_message.start);
1510 break;
1511 } else {
1512 end_ix += 1;
1513 message_anchors.next();
1514 }
1515 }
1516 let message_end = message_end
1517 .unwrap_or(language::Anchor::MAX)
1518 .to_offset(buffer);
1519 return Some(Message {
1520 index_range: start_ix..end_ix,
1521 offset_range: message_start..message_end,
1522 id: message_anchor.id,
1523 anchor: message_anchor.start,
1524 role: metadata.role,
1525 sent_at: metadata.sent_at,
1526 status: metadata.status.clone(),
1527 });
1528 }
1529 None
1530 })
1531 }
1532
1533 fn save(
1534 &mut self,
1535 debounce: Option<Duration>,
1536 fs: Arc<dyn Fs>,
1537 cx: &mut ModelContext<Conversation>,
1538 ) {
1539 self.pending_save = cx.spawn(|this, mut cx| async move {
1540 if let Some(debounce) = debounce {
1541 cx.background().timer(debounce).await;
1542 }
1543
1544 let (old_path, summary) = this.read_with(&cx, |this, _| {
1545 let path = this.path.clone();
1546 let summary = if let Some(summary) = this.summary.as_ref() {
1547 if summary.done {
1548 Some(summary.text.clone())
1549 } else {
1550 None
1551 }
1552 } else {
1553 None
1554 };
1555 (path, summary)
1556 });
1557
1558 if let Some(summary) = summary {
1559 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1560 let path = if let Some(old_path) = old_path {
1561 old_path
1562 } else {
1563 let mut discriminant = 1;
1564 let mut new_path;
1565 loop {
1566 new_path = CONVERSATIONS_DIR.join(&format!(
1567 "{} - {}.zed.json",
1568 summary.trim(),
1569 discriminant
1570 ));
1571 if fs.is_file(&new_path).await {
1572 discriminant += 1;
1573 } else {
1574 break;
1575 }
1576 }
1577 new_path
1578 };
1579
1580 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1581 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1582 .await?;
1583 this.update(&mut cx, |this, _| this.path = Some(path));
1584 }
1585
1586 Ok(())
1587 });
1588 }
1589}
1590
/// The in-flight completion tasks started by a single `Conversation::assist`
/// call.
struct PendingCompletion {
    /// Value of the conversation's `completion_count` at the time this
    /// completion was recorded.
    id: usize,
    /// The spawned streaming tasks, held for as long as this value is kept in
    /// `pending_completions`.
    // NOTE(review): presumably dropping these cancels the streams — confirm
    // gpui `Task` drop semantics.
    _tasks: Vec<Task<()>>,
}
1595
/// Events emitted by `ConversationEditor`.
enum ConversationEditorEvent {
    /// Emitted when the conversation's summary changes.
    TabContentChanged,
}
1599
/// Where the cursor sits relative to the editor's scroll position.
///
/// Recorded while the cursor is visible so that the same relative position
/// can be restored while a completion streams in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// Horizontal scroll position (x) and the number of rows between the top
    /// of the viewport and the cursor (y).
    offset_before_cursor: Vector2F,
    /// Head of the newest selection.
    cursor: Anchor,
}
1605
/// View that displays a `Conversation` in an editor with a header above
/// every message.
struct ConversationEditor {
    /// The conversation being displayed and edited.
    conversation: ModelHandle<Conversation>,
    /// File system handle passed to `Conversation::save`.
    fs: Arc<dyn Fs>,
    /// Editor showing the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message header blocks currently inserted into the editor.
    blocks: HashSet<BlockId>,
    /// Cursor position relative to the viewport, if the cursor was visible
    /// when it was last recorded.
    scroll_position: Option<ScrollPosition>,
    /// Subscriptions to the conversation and the editor, kept for the
    /// lifetime of the view.
    _subscriptions: Vec<Subscription>,
}
1614
1615impl ConversationEditor {
1616 fn new(
1617 api_key: Rc<RefCell<Option<String>>>,
1618 language_registry: Arc<LanguageRegistry>,
1619 fs: Arc<dyn Fs>,
1620 cx: &mut ViewContext<Self>,
1621 ) -> Self {
1622 let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
1623 Self::for_conversation(conversation, fs, cx)
1624 }
1625
1626 fn for_conversation(
1627 conversation: ModelHandle<Conversation>,
1628 fs: Arc<dyn Fs>,
1629 cx: &mut ViewContext<Self>,
1630 ) -> Self {
1631 let editor = cx.add_view(|cx| {
1632 let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
1633 editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
1634 editor.set_show_gutter(false, cx);
1635 editor
1636 });
1637
1638 let _subscriptions = vec![
1639 cx.observe(&conversation, |_, _, cx| cx.notify()),
1640 cx.subscribe(&conversation, Self::handle_conversation_event),
1641 cx.subscribe(&editor, Self::handle_editor_event),
1642 ];
1643
1644 let mut this = Self {
1645 conversation,
1646 editor,
1647 blocks: Default::default(),
1648 scroll_position: None,
1649 fs,
1650 _subscriptions,
1651 };
1652 this.update_message_headers(cx);
1653 this
1654 }
1655
1656 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1657 let cursors = self.cursors(cx);
1658
1659 let user_messages = self.conversation.update(cx, |conversation, cx| {
1660 let selected_messages = conversation
1661 .messages_for_offsets(cursors, cx)
1662 .into_iter()
1663 .map(|message| message.id)
1664 .collect();
1665 conversation.assist(selected_messages, cx)
1666 });
1667 let new_selections = user_messages
1668 .iter()
1669 .map(|message| {
1670 let cursor = message
1671 .start
1672 .to_offset(self.conversation.read(cx).buffer.read(cx));
1673 cursor..cursor
1674 })
1675 .collect::<Vec<_>>();
1676 if !new_selections.is_empty() {
1677 self.editor.update(cx, |editor, cx| {
1678 editor.change_selections(
1679 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1680 cx,
1681 |selections| selections.select_ranges(new_selections),
1682 );
1683 });
1684 }
1685 }
1686
1687 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1688 if !self
1689 .conversation
1690 .update(cx, |conversation, _| conversation.cancel_last_assist())
1691 {
1692 cx.propagate_action();
1693 }
1694 }
1695
1696 fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
1697 let cursors = self.cursors(cx);
1698 self.conversation.update(cx, |conversation, cx| {
1699 let messages = conversation
1700 .messages_for_offsets(cursors, cx)
1701 .into_iter()
1702 .map(|message| message.id)
1703 .collect();
1704 conversation.cycle_message_roles(messages, cx)
1705 });
1706 }
1707
1708 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1709 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1710 selections
1711 .into_iter()
1712 .map(|selection| selection.head())
1713 .collect()
1714 }
1715
1716 fn handle_conversation_event(
1717 &mut self,
1718 _: ModelHandle<Conversation>,
1719 event: &ConversationEvent,
1720 cx: &mut ViewContext<Self>,
1721 ) {
1722 match event {
1723 ConversationEvent::MessagesEdited => {
1724 self.update_message_headers(cx);
1725 self.conversation.update(cx, |conversation, cx| {
1726 conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
1727 });
1728 }
1729 ConversationEvent::SummaryChanged => {
1730 cx.emit(ConversationEditorEvent::TabContentChanged);
1731 self.conversation.update(cx, |conversation, cx| {
1732 conversation.save(None, self.fs.clone(), cx);
1733 });
1734 }
1735 ConversationEvent::StreamedCompletion => {
1736 self.editor.update(cx, |editor, cx| {
1737 if let Some(scroll_position) = self.scroll_position {
1738 let snapshot = editor.snapshot(cx);
1739 let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
1740 let scroll_top =
1741 cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
1742 editor.set_scroll_position(
1743 vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
1744 cx,
1745 );
1746 }
1747 });
1748 }
1749 }
1750 }
1751
1752 fn handle_editor_event(
1753 &mut self,
1754 _: ViewHandle<Editor>,
1755 event: &editor::Event,
1756 cx: &mut ViewContext<Self>,
1757 ) {
1758 match event {
1759 editor::Event::ScrollPositionChanged { autoscroll, .. } => {
1760 let cursor_scroll_position = self.cursor_scroll_position(cx);
1761 if *autoscroll {
1762 self.scroll_position = cursor_scroll_position;
1763 } else if self.scroll_position != cursor_scroll_position {
1764 self.scroll_position = None;
1765 }
1766 }
1767 editor::Event::SelectionsChanged { .. } => {
1768 self.scroll_position = self.cursor_scroll_position(cx);
1769 }
1770 _ => {}
1771 }
1772 }
1773
1774 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1775 self.editor.update(cx, |editor, cx| {
1776 let snapshot = editor.snapshot(cx);
1777 let cursor = editor.selections.newest_anchor().head();
1778 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1779 let scroll_position = editor
1780 .scroll_manager
1781 .anchor()
1782 .scroll_position(&snapshot.display_snapshot);
1783
1784 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1785 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1786 Some(ScrollPosition {
1787 cursor,
1788 offset_before_cursor: vec2f(
1789 scroll_position.x(),
1790 cursor_row - scroll_position.y(),
1791 ),
1792 })
1793 } else {
1794 None
1795 }
1796 })
1797 }
1798
1799 fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
1800 self.editor.update(cx, |editor, cx| {
1801 let buffer = editor.buffer().read(cx).snapshot(cx);
1802 let excerpt_id = *buffer.as_singleton().unwrap().0;
1803 let old_blocks = std::mem::take(&mut self.blocks);
1804 let new_blocks = self
1805 .conversation
1806 .read(cx)
1807 .messages(cx)
1808 .map(|message| BlockProperties {
1809 position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
1810 height: 2,
1811 style: BlockStyle::Sticky,
1812 render: Arc::new({
1813 let conversation = self.conversation.clone();
1814 // let metadata = message.metadata.clone();
1815 // let message = message.clone();
1816 move |cx| {
1817 enum Sender {}
1818 enum ErrorTooltip {}
1819
1820 let theme = theme::current(cx);
1821 let style = &theme.assistant;
1822 let message_id = message.id;
1823 let sender = MouseEventHandler::<Sender, _>::new(
1824 message_id.0,
1825 cx,
1826 |state, _| match message.role {
1827 Role::User => {
1828 let style = style.user_sender.style_for(state);
1829 Label::new("You", style.text.clone())
1830 .contained()
1831 .with_style(style.container)
1832 }
1833 Role::Assistant => {
1834 let style = style.assistant_sender.style_for(state);
1835 Label::new("Assistant", style.text.clone())
1836 .contained()
1837 .with_style(style.container)
1838 }
1839 Role::System => {
1840 let style = style.system_sender.style_for(state);
1841 Label::new("System", style.text.clone())
1842 .contained()
1843 .with_style(style.container)
1844 }
1845 },
1846 )
1847 .with_cursor_style(CursorStyle::PointingHand)
1848 .on_down(MouseButton::Left, {
1849 let conversation = conversation.clone();
1850 move |_, _, cx| {
1851 conversation.update(cx, |conversation, cx| {
1852 conversation.cycle_message_roles(
1853 HashSet::from_iter(Some(message_id)),
1854 cx,
1855 )
1856 })
1857 }
1858 });
1859
1860 Flex::row()
1861 .with_child(sender.aligned())
1862 .with_child(
1863 Label::new(
1864 message.sent_at.format("%I:%M%P").to_string(),
1865 style.sent_at.text.clone(),
1866 )
1867 .contained()
1868 .with_style(style.sent_at.container)
1869 .aligned(),
1870 )
1871 .with_children(
1872 if let MessageStatus::Error(error) = &message.status {
1873 Some(
1874 Svg::new("icons/circle_x_mark_12.svg")
1875 .with_color(style.error_icon.color)
1876 .constrained()
1877 .with_width(style.error_icon.width)
1878 .contained()
1879 .with_style(style.error_icon.container)
1880 .with_tooltip::<ErrorTooltip>(
1881 message_id.0,
1882 error.to_string(),
1883 None,
1884 theme.tooltip.clone(),
1885 cx,
1886 )
1887 .aligned(),
1888 )
1889 } else {
1890 None
1891 },
1892 )
1893 .aligned()
1894 .left()
1895 .contained()
1896 .with_style(style.message_header)
1897 .into_any()
1898 }
1899 }),
1900 disposition: BlockDisposition::Above,
1901 })
1902 .collect::<Vec<_>>();
1903
1904 editor.remove_blocks(old_blocks, None, cx);
1905 let ids = editor.insert_blocks(new_blocks, None, cx);
1906 self.blocks = HashSet::from_iter(ids);
1907 });
1908 }
1909
1910 fn quote_selection(
1911 workspace: &mut Workspace,
1912 _: &QuoteSelection,
1913 cx: &mut ViewContext<Workspace>,
1914 ) {
1915 let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
1916 return;
1917 };
1918 let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
1919 return;
1920 };
1921
1922 let text = editor.read_with(cx, |editor, cx| {
1923 let range = editor.selections.newest::<usize>(cx).range();
1924 let buffer = editor.buffer().read(cx).snapshot(cx);
1925 let start_language = buffer.language_at(range.start);
1926 let end_language = buffer.language_at(range.end);
1927 let language_name = if start_language == end_language {
1928 start_language.map(|language| language.name())
1929 } else {
1930 None
1931 };
1932 let language_name = language_name.as_deref().unwrap_or("").to_lowercase();
1933
1934 let selected_text = buffer.text_for_range(range).collect::<String>();
1935 if selected_text.is_empty() {
1936 None
1937 } else {
1938 Some(if language_name == "markdown" {
1939 selected_text
1940 .lines()
1941 .map(|line| format!("> {}", line))
1942 .collect::<Vec<_>>()
1943 .join("\n")
1944 } else {
1945 format!("```{language_name}\n{selected_text}\n```")
1946 })
1947 }
1948 });
1949
1950 // Activate the panel
1951 if !panel.read(cx).has_focus(cx) {
1952 workspace.toggle_panel_focus::<AssistantPanel>(cx);
1953 }
1954
1955 if let Some(text) = text {
1956 panel.update(cx, |panel, cx| {
1957 let conversation = panel
1958 .active_editor()
1959 .cloned()
1960 .unwrap_or_else(|| panel.new_conversation(cx));
1961 conversation.update(cx, |conversation, cx| {
1962 conversation
1963 .editor
1964 .update(cx, |editor, cx| editor.insert(&text, cx))
1965 });
1966 });
1967 }
1968 }
1969
1970 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1971 let editor = self.editor.read(cx);
1972 let conversation = self.conversation.read(cx);
1973 if editor.selections.count() == 1 {
1974 let selection = editor.selections.newest::<usize>(cx);
1975 let mut copied_text = String::new();
1976 let mut spanned_messages = 0;
1977 for message in conversation.messages(cx) {
1978 if message.offset_range.start >= selection.range().end {
1979 break;
1980 } else if message.offset_range.end >= selection.range().start {
1981 let range = cmp::max(message.offset_range.start, selection.range().start)
1982 ..cmp::min(message.offset_range.end, selection.range().end);
1983 if !range.is_empty() {
1984 spanned_messages += 1;
1985 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1986 for chunk in conversation.buffer.read(cx).text_for_range(range) {
1987 copied_text.push_str(&chunk);
1988 }
1989 copied_text.push('\n');
1990 }
1991 }
1992 }
1993
1994 if spanned_messages > 1 {
1995 cx.platform()
1996 .write_to_clipboard(ClipboardItem::new(copied_text));
1997 return;
1998 }
1999 }
2000
2001 cx.propagate_action();
2002 }
2003
2004 fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
2005 self.conversation.update(cx, |conversation, cx| {
2006 let selections = self.editor.read(cx).selections.disjoint_anchors();
2007 for selection in selections.into_iter() {
2008 let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
2009 let range = selection
2010 .map(|endpoint| endpoint.to_offset(&buffer))
2011 .range();
2012 conversation.split_message(range, cx);
2013 }
2014 });
2015 }
2016
2017 fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
2018 self.conversation.update(cx, |conversation, cx| {
2019 conversation.save(None, self.fs.clone(), cx)
2020 });
2021 }
2022
2023 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2024 self.conversation.update(cx, |conversation, cx| {
2025 let new_model = match conversation.model.as_str() {
2026 "gpt-4-0613" => "gpt-3.5-turbo-0613",
2027 _ => "gpt-4-0613",
2028 };
2029 conversation.set_model(new_model.into(), cx);
2030 });
2031 }
2032
2033 fn title(&self, cx: &AppContext) -> String {
2034 self.conversation
2035 .read(cx)
2036 .summary
2037 .as_ref()
2038 .map(|summary| summary.text.clone())
2039 .unwrap_or_else(|| "New Conversation".into())
2040 }
2041
2042 fn render_current_model(
2043 &self,
2044 style: &AssistantStyle,
2045 cx: &mut ViewContext<Self>,
2046 ) -> impl Element<Self> {
2047 enum Model {}
2048
2049 MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
2050 let style = style.model.style_for(state);
2051 Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
2052 .contained()
2053 .with_style(style.container)
2054 })
2055 .with_cursor_style(CursorStyle::PointingHand)
2056 .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
2057 }
2058
2059 fn render_remaining_tokens(
2060 &self,
2061 style: &AssistantStyle,
2062 cx: &mut ViewContext<Self>,
2063 ) -> Option<impl Element<Self>> {
2064 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2065 let remaining_tokens_style = if remaining_tokens <= 0 {
2066 &style.no_remaining_tokens
2067 } else {
2068 &style.remaining_tokens
2069 };
2070 Some(
2071 Label::new(
2072 remaining_tokens.to_string(),
2073 remaining_tokens_style.text.clone(),
2074 )
2075 .contained()
2076 .with_style(remaining_tokens_style.container),
2077 )
2078 }
2079}
2080
/// `ConversationEditor` emits `ConversationEditorEvent`s.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2084
2085impl View for ConversationEditor {
2086 fn ui_name() -> &'static str {
2087 "ConversationEditor"
2088 }
2089
2090 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2091 let theme = &theme::current(cx).assistant;
2092 Stack::new()
2093 .with_child(
2094 ChildView::new(&self.editor, cx)
2095 .contained()
2096 .with_style(theme.container),
2097 )
2098 .with_child(
2099 Flex::row()
2100 .with_child(self.render_current_model(theme, cx))
2101 .with_children(self.render_remaining_tokens(theme, cx))
2102 .aligned()
2103 .top()
2104 .right(),
2105 )
2106 .into_any()
2107 }
2108
2109 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2110 if cx.is_self_focused() {
2111 cx.focus(&self.editor);
2112 }
2113 }
2114}
2115
/// Identifies a message and marks where it starts in the conversation buffer.
#[derive(Clone, Debug)]
struct MessageAnchor {
    /// Identifier of the message.
    id: MessageId,
    /// Buffer anchor at which the message's text begins.
    start: language::Anchor,
}
2121
/// A message of a conversation, resolved against the current buffer contents.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets from the start of this message to the start of the next
    /// valid message (or the end of the buffer).
    offset_range: Range<usize>,
    /// Range of indices into the conversation's message anchors: from this
    /// message's anchor through the invalid anchors folded into it.
    index_range: Range<usize>,
    /// Identifier of the message.
    id: MessageId,
    /// Buffer anchor at which the message starts.
    anchor: language::Anchor,
    /// Who the message is attributed to.
    role: Role,
    /// Timestamp taken from the message's metadata.
    sent_at: DateTime<Local>,
    /// Status taken from the message's metadata.
    status: MessageStatus,
}
2132
2133impl Message {
2134 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2135 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2136 content.extend(buffer.text_for_range(self.offset_range.clone()));
2137 RequestMessage {
2138 role: self.role,
2139 content: content.trim_end().into(),
2140 }
2141 }
2142}
2143
2144async fn stream_completion(
2145 api_key: String,
2146 executor: Arc<Background>,
2147 mut request: OpenAIRequest,
2148) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2149 request.stream = true;
2150
2151 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2152
2153 let json_data = serde_json::to_string(&request)?;
2154 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2155 .header("Content-Type", "application/json")
2156 .header("Authorization", format!("Bearer {}", api_key))
2157 .body(json_data)?
2158 .send_async()
2159 .await?;
2160
2161 let status = response.status();
2162 if status == StatusCode::OK {
2163 executor
2164 .spawn(async move {
2165 let mut lines = BufReader::new(response.body_mut()).lines();
2166
2167 fn parse_line(
2168 line: Result<String, io::Error>,
2169 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2170 if let Some(data) = line?.strip_prefix("data: ") {
2171 let event = serde_json::from_str(&data)?;
2172 Ok(Some(event))
2173 } else {
2174 Ok(None)
2175 }
2176 }
2177
2178 while let Some(line) = lines.next().await {
2179 if let Some(event) = parse_line(line).transpose() {
2180 let done = event.as_ref().map_or(false, |event| {
2181 event
2182 .choices
2183 .last()
2184 .map_or(false, |choice| choice.finish_reason.is_some())
2185 });
2186 if tx.unbounded_send(event).is_err() {
2187 break;
2188 }
2189
2190 if done {
2191 break;
2192 }
2193 }
2194 }
2195
2196 anyhow::Ok(())
2197 })
2198 .detach();
2199
2200 Ok(rx)
2201 } else {
2202 let mut body = String::new();
2203 response.body_mut().read_to_string(&mut body).await?;
2204
2205 #[derive(Deserialize)]
2206 struct OpenAIResponse {
2207 error: OpenAIError,
2208 }
2209
2210 #[derive(Deserialize)]
2211 struct OpenAIError {
2212 message: String,
2213 }
2214
2215 match serde_json::from_str::<OpenAIResponse>(&body) {
2216 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2217 "Failed to connect to OpenAI API: {}",
2218 response.error.message,
2219 )),
2220
2221 _ => Err(anyhow!(
2222 "Failed to connect to OpenAI API: {} {}",
2223 response.status(),
2224 body,
2225 )),
2226 }
2227 }
2228}
2229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    /// Summarizes a conversation as `(id, role, offset range)` triples, in the
    /// order yielded by `Conversation::messages`.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        let mut summary = Vec::new();
        for message in conversation.read(cx).messages(cx) {
            summary.push((message.id, message.role, message.offset_range));
        }
        summary
    }

    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let languages = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), languages, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A new conversation starts out with one empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            [(message_1.id, Role::User, 0..0)]
        );

        let message_2 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1),
            ]
        );

        buffer.update(cx, |buf, cx| {
            buf.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3),
            ]
        );

        let message_3 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4),
            ]
        );

        // A second insertion after message 2 lands in front of message 3.
        let message_4 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buf, cx| {
            buf.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // A deletion that spans message boundaries merges the affected messages.
        buffer.update(cx, |buf, cx| buf.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undo brings back the deleted text and, with it, the separate messages.
        buffer.update(cx, |buf, cx| buf.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redo applies the deletion again and therefore the merge as well.
        buffer.update(cx, |buf, cx| buf.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Inserting after a message that absorbed others must still work.
        let message_5 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5),
            ]
        );
    }

    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let languages = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), languages, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            [(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buf, cx| {
            buf.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let message_2 = conversation
            .update(cx, |this, cx| this.split_message(3..3, cx))
            .1
            .unwrap();

        // A newline already present at the split point is reused.
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let message_3 = conversation
            .update(cx, |this, cx| this.split_message(3..3, cx))
            .1
            .unwrap();

        // The newline that ends a message is not reused; a new one is inserted.
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let message_4 = conversation
            .update(cx, |this, cx| this.split_message(9..9, cx))
            .1
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let message_5 = conversation
            .update(cx, |this, cx| this.split_message(9..9, cx))
            .1
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting around a non-empty range yields two new messages.
        let (message_6, message_7) =
            conversation.update(cx, |this, cx| this.split_message(14..16, cx));
        let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap());
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        /// Ids of the messages returned by `messages_for_offsets` for `offsets`.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            let found = conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx);
            found.into_iter().map(|message| message.id).collect()
        }

        let languages = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), languages, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            [(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buf, cx| buf.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buf, cx| buf.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buf, cx| buf.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11),
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Two offsets inside the same message report that message only once.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        let message_4 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12),
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );
    }

    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let languages = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), languages.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buf, cx| {
            buf.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buf.finalize_last_transaction();
        });

        // Insert one more message, then undo that insertion before serializing.
        let _message_3 = conversation.update(cx, |this, cx| {
            this.insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buf, cx| buf.undo(cx));

        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        // A round trip through serialize/deserialize must preserve text and messages.
        let restored = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                languages.clone(),
                cx,
            )
        });
        let restored_buffer = restored.read(cx).buffer.clone();
        assert_eq!(restored_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&restored, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }
}