1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI HTTP API. Also used in this file as the key under
/// which the API key is written to, read from and deleted from the platform
/// credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Declares the actions of the `assistant` namespace. The listed names are used
// elsewhere in this file as unit-like values/types (e.g. `Box::new(Split)`,
// `&ResetKey`) for tooltips and action handlers.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
/// Registers the assistant's settings and all action handlers with the app.
///
/// Handlers are registered with either `add_action` or `capture_action`.
/// NOTE(review): the difference between the two (presumably dispatch phase)
/// is defined by gpui and not visible here — confirm before reordering.
pub fn init(cx: &mut AppContext) {
    settings::register::<AssistantSettings>(cx);
    // `workspace::NewFile` dispatched on the panel starts a new conversation.
    cx.add_action(
        |this: &mut AssistantPanel,
         _: &workspace::NewFile,
         cx: &mut ViewContext<AssistantPanel>| {
            this.new_conversation(cx);
        },
    );
    // Conversation editor actions.
    cx.add_action(ConversationEditor::assist);
    cx.capture_action(ConversationEditor::cancel_last_assist);
    cx.capture_action(ConversationEditor::save);
    cx.add_action(ConversationEditor::quote_selection);
    cx.capture_action(ConversationEditor::copy);
    cx.add_action(ConversationEditor::split);
    cx.capture_action(ConversationEditor::cycle_message_role);
    // Panel-level actions: API key management, zoom and buffer search.
    cx.add_action(AssistantPanel::save_api_key);
    cx.add_action(AssistantPanel::reset_api_key);
    cx.add_action(AssistantPanel::toggle_zoom);
    cx.add_action(AssistantPanel::deploy);
    cx.add_action(AssistantPanel::select_next_match);
    cx.add_action(AssistantPanel::select_prev_match);
    cx.add_action(AssistantPanel::handle_editor_cancel);
    // Workspace-level action that toggles focus of the assistant panel.
    cx.add_action(
        |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        },
    );
}
93
/// Events emitted by [`AssistantPanel`]. The `Panel` implementation in this
/// file maps each variant to a dock reaction (zoom, focus, close, reposition).
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` when the panel is not currently zoomed.
    ZoomIn,
    /// Emitted by `toggle_zoom` when the panel is currently zoomed.
    ZoomOut,
    /// Recognized by `Panel::is_focus_event`.
    Focus,
    /// Recognized by `Panel::should_close_on_event`.
    Close,
    /// Emitted when the dock position derived from settings changes.
    DockPositionChanged,
}
102
/// The assistant dock panel: hosts open conversation editors, the list of
/// saved conversations, and the API key prompt.
pub struct AssistantPanel {
    /// Weak handle to the workspace that owns this panel.
    workspace: WeakViewHandle<Workspace>,
    /// User-chosen size while docked left/right; `None` uses the settings default.
    width: Option<f32>,
    /// User-chosen size while docked at the bottom; `None` uses the settings default.
    height: Option<f32>,
    /// Index into `editors` of the conversation being shown; `None` shows the
    /// saved-conversations list instead.
    active_editor_index: Option<usize>,
    /// Value of `active_editor_index` before its most recent change.
    prev_active_editor_index: Option<usize>,
    /// All conversation editors opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of the conversations saved on disk.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// List state for the saved-conversations `UniformList`.
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    has_focus: bool,
    /// Toolbar hosting the buffer search bar.
    toolbar: ViewHandle<Toolbar>,
    /// API key shared (via `Rc`) with every conversation created by the panel.
    api_key: Rc<RefCell<Option<String>>>,
    /// Present while the panel is prompting the user for an API key.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once the panel has attempted to load the key from the environment
    /// or the platform credential store.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Subscriptions kept alive for the lifetime of the panel.
    subscriptions: Vec<Subscription>,
    /// Task that watches the conversations directory and refreshes
    /// `saved_conversations`.
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
125 pub fn load(
126 workspace: WeakViewHandle<Workspace>,
127 cx: AsyncAppContext,
128 ) -> Task<Result<ViewHandle<Self>>> {
129 cx.spawn(|mut cx| async move {
130 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
131 let saved_conversations = SavedConversationMetadata::list(fs.clone())
132 .await
133 .log_err()
134 .unwrap_or_default();
135
136 // TODO: deserialize state.
137 let workspace_handle = workspace.clone();
138 workspace.update(&mut cx, |workspace, cx| {
139 cx.add_view::<Self, _>(|cx| {
140 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
141 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
142 let mut events = fs
143 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
144 .await;
145 while events.next().await.is_some() {
146 let saved_conversations = SavedConversationMetadata::list(fs.clone())
147 .await
148 .log_err()
149 .unwrap_or_default();
150 this.update(&mut cx, |this, _| {
151 this.saved_conversations = saved_conversations
152 })
153 .ok();
154 }
155
156 anyhow::Ok(())
157 });
158
159 let toolbar = cx.add_view(|cx| {
160 let mut toolbar = Toolbar::new(None);
161 toolbar.set_can_navigate(false, cx);
162 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
163 toolbar
164 });
165 let mut this = Self {
166 workspace: workspace_handle,
167 active_editor_index: Default::default(),
168 prev_active_editor_index: Default::default(),
169 editors: Default::default(),
170 saved_conversations,
171 saved_conversations_list_state: Default::default(),
172 zoomed: false,
173 has_focus: false,
174 toolbar,
175 api_key: Rc::new(RefCell::new(None)),
176 api_key_editor: None,
177 has_read_credentials: false,
178 languages: workspace.app_state().languages.clone(),
179 fs: workspace.app_state().fs.clone(),
180 width: None,
181 height: None,
182 subscriptions: Default::default(),
183 _watch_saved_conversations,
184 };
185
186 let mut old_dock_position = this.position(cx);
187 this.subscriptions =
188 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
189 let new_dock_position = this.position(cx);
190 if new_dock_position != old_dock_position {
191 old_dock_position = new_dock_position;
192 cx.emit(AssistantPanelEvent::DockPositionChanged);
193 }
194 })];
195
196 this
197 })
198 })
199 })
200 }
201
202 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
203 let editor = cx.add_view(|cx| {
204 ConversationEditor::new(
205 self.api_key.clone(),
206 self.languages.clone(),
207 self.fs.clone(),
208 cx,
209 )
210 });
211 self.add_conversation(editor.clone(), cx);
212 editor
213 }
214
215 fn add_conversation(
216 &mut self,
217 editor: ViewHandle<ConversationEditor>,
218 cx: &mut ViewContext<Self>,
219 ) {
220 self.subscriptions
221 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
222
223 let conversation = editor.read(cx).conversation.clone();
224 self.subscriptions
225 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
226
227 let index = self.editors.len();
228 self.editors.push(editor);
229 self.set_active_editor_index(Some(index), cx);
230 }
231
232 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
233 self.prev_active_editor_index = self.active_editor_index;
234 self.active_editor_index = index;
235 if let Some(editor) = self.active_editor() {
236 let editor = editor.read(cx).editor.clone();
237 self.toolbar.update(cx, |toolbar, cx| {
238 toolbar.set_active_item(Some(&editor), cx);
239 });
240 if self.has_focus(cx) {
241 cx.focus(&editor);
242 }
243 } else {
244 self.toolbar.update(cx, |toolbar, cx| {
245 toolbar.set_active_item(None, cx);
246 });
247 }
248
249 cx.notify();
250 }
251
    /// Reacts to events from a conversation editor the panel subscribed to.
    /// The only event handled here, `TabContentChanged`, re-renders the panel.
    fn handle_conversation_editor_event(
        &mut self,
        _: ViewHandle<ConversationEditor>,
        event: &ConversationEditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEditorEvent::TabContentChanged => cx.notify(),
        }
    }
262
263 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
264 if let Some(api_key) = self
265 .api_key_editor
266 .as_ref()
267 .map(|editor| editor.read(cx).text(cx))
268 {
269 if !api_key.is_empty() {
270 cx.platform()
271 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
272 .log_err();
273 *self.api_key.borrow_mut() = Some(api_key);
274 self.api_key_editor.take();
275 cx.focus_self();
276 cx.notify();
277 }
278 } else {
279 cx.propagate_action();
280 }
281 }
282
    /// Handles `ResetKey`: deletes the stored credentials (failure is logged),
    /// clears the in-memory key and shows a fresh API key prompt.
    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
        self.api_key.take();
        self.api_key_editor = Some(build_api_key_editor(cx));
        cx.focus_self();
        cx.notify();
    }
290
291 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
292 if self.zoomed {
293 cx.emit(AssistantPanelEvent::ZoomOut)
294 } else {
295 cx.emit(AssistantPanelEvent::ZoomIn)
296 }
297 }
298
299 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
300 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
301 if search_bar.update(cx, |search_bar, cx| search_bar.show(action.focus, true, cx)) {
302 return;
303 }
304 }
305 cx.propagate_action();
306 }
307
308 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
309 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
310 if !search_bar.read(cx).is_dismissed() {
311 search_bar.update(cx, |search_bar, cx| {
312 search_bar.dismiss(&Default::default(), cx)
313 });
314 return;
315 }
316 }
317 cx.propagate_action();
318 }
319
320 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
321 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
322 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, cx));
323 }
324 }
325
326 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
327 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
328 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, cx));
329 }
330 }
331
332 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
333 self.editors.get(self.active_editor_index?)
334 }
335
336 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
337 enum History {}
338 let theme = theme::current(cx);
339 let tooltip_style = theme::current(cx).tooltip.clone();
340 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
341 let style = theme.assistant.hamburger_button.style_for(state);
342 Svg::for_style(style.icon.clone())
343 .contained()
344 .with_style(style.container)
345 })
346 .with_cursor_style(CursorStyle::PointingHand)
347 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
348 if this.active_editor().is_some() {
349 this.set_active_editor_index(None, cx);
350 } else {
351 this.set_active_editor_index(this.prev_active_editor_index, cx);
352 }
353 })
354 .with_tooltip::<History>(1, "History".into(), None, tooltip_style, cx)
355 }
356
357 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
358 if self.active_editor().is_some() {
359 vec![
360 Self::render_split_button(cx).into_any(),
361 Self::render_quote_button(cx).into_any(),
362 Self::render_assist_button(cx).into_any(),
363 ]
364 } else {
365 Default::default()
366 }
367 }
368
369 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
370 let theme = theme::current(cx);
371 let tooltip_style = theme::current(cx).tooltip.clone();
372 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
373 let style = theme.assistant.split_button.style_for(state);
374 Svg::for_style(style.icon.clone())
375 .contained()
376 .with_style(style.container)
377 })
378 .with_cursor_style(CursorStyle::PointingHand)
379 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
380 if let Some(active_editor) = this.active_editor() {
381 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
382 }
383 })
384 .with_tooltip::<Split>(
385 1,
386 "Split Message".into(),
387 Some(Box::new(Split)),
388 tooltip_style,
389 cx,
390 )
391 }
392
393 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
394 let theme = theme::current(cx);
395 let tooltip_style = theme::current(cx).tooltip.clone();
396 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
397 let style = theme.assistant.assist_button.style_for(state);
398 Svg::for_style(style.icon.clone())
399 .contained()
400 .with_style(style.container)
401 })
402 .with_cursor_style(CursorStyle::PointingHand)
403 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
404 if let Some(active_editor) = this.active_editor() {
405 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
406 }
407 })
408 .with_tooltip::<Assist>(
409 1,
410 "Assist".into(),
411 Some(Box::new(Assist)),
412 tooltip_style,
413 cx,
414 )
415 }
416
417 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
418 let theme = theme::current(cx);
419 let tooltip_style = theme::current(cx).tooltip.clone();
420 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
421 let style = theme.assistant.quote_button.style_for(state);
422 Svg::for_style(style.icon.clone())
423 .contained()
424 .with_style(style.container)
425 })
426 .with_cursor_style(CursorStyle::PointingHand)
427 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
428 if let Some(workspace) = this.workspace.upgrade(cx) {
429 cx.window_context().defer(move |cx| {
430 workspace.update(cx, |workspace, cx| {
431 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
432 });
433 });
434 }
435 })
436 .with_tooltip::<QuoteSelection>(
437 1,
438 "Quote Selection".into(),
439 Some(Box::new(QuoteSelection)),
440 tooltip_style,
441 cx,
442 )
443 }
444
445 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
446 let theme = theme::current(cx);
447 let tooltip_style = theme::current(cx).tooltip.clone();
448 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
449 let style = theme.assistant.plus_button.style_for(state);
450 Svg::for_style(style.icon.clone())
451 .contained()
452 .with_style(style.container)
453 })
454 .with_cursor_style(CursorStyle::PointingHand)
455 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
456 this.new_conversation(cx);
457 })
458 .with_tooltip::<NewConversation>(
459 1,
460 "New Conversation".into(),
461 Some(Box::new(NewConversation)),
462 tooltip_style,
463 cx,
464 )
465 }
466
467 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
468 enum ToggleZoomButton {}
469
470 let theme = theme::current(cx);
471 let tooltip_style = theme::current(cx).tooltip.clone();
472 let style = if self.zoomed {
473 &theme.assistant.zoom_out_button
474 } else {
475 &theme.assistant.zoom_in_button
476 };
477
478 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
479 let style = style.style_for(state);
480 Svg::for_style(style.icon.clone())
481 .contained()
482 .with_style(style.container)
483 })
484 .with_cursor_style(CursorStyle::PointingHand)
485 .on_click(MouseButton::Left, |_, this, cx| {
486 this.toggle_zoom(&ToggleZoom, cx);
487 })
488 .with_tooltip::<ToggleZoom>(
489 0,
490 if self.zoomed {
491 "Zoom Out".into()
492 } else {
493 "Zoom In".into()
494 },
495 Some(Box::new(ToggleZoom)),
496 tooltip_style,
497 cx,
498 )
499 }
500
501 fn render_saved_conversation(
502 &mut self,
503 index: usize,
504 cx: &mut ViewContext<Self>,
505 ) -> impl Element<Self> {
506 let conversation = &self.saved_conversations[index];
507 let path = conversation.path.clone();
508 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
509 let style = &theme::current(cx).assistant.saved_conversation;
510 Flex::row()
511 .with_child(
512 Label::new(
513 conversation.mtime.format("%F %I:%M%p").to_string(),
514 style.saved_at.text.clone(),
515 )
516 .aligned()
517 .contained()
518 .with_style(style.saved_at.container),
519 )
520 .with_child(
521 Label::new(conversation.title.clone(), style.title.text.clone())
522 .aligned()
523 .contained()
524 .with_style(style.title.container),
525 )
526 .contained()
527 .with_style(*style.container.style_for(state))
528 })
529 .with_cursor_style(CursorStyle::PointingHand)
530 .on_click(MouseButton::Left, move |_, this, cx| {
531 this.open_conversation(path.clone(), cx)
532 .detach_and_log_err(cx)
533 })
534 }
535
536 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
537 if let Some(ix) = self.editor_index_for_path(&path, cx) {
538 self.set_active_editor_index(Some(ix), cx);
539 return Task::ready(Ok(()));
540 }
541
542 let fs = self.fs.clone();
543 let api_key = self.api_key.clone();
544 let languages = self.languages.clone();
545 cx.spawn(|this, mut cx| async move {
546 let saved_conversation = fs.load(&path).await?;
547 let saved_conversation = serde_json::from_str(&saved_conversation)?;
548 let conversation = cx.add_model(|cx| {
549 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
550 });
551 this.update(&mut cx, |this, cx| {
552 // If, by the time we've loaded the conversation, the user has already opened
553 // the same conversation, we don't want to open it again.
554 if let Some(ix) = this.editor_index_for_path(&path, cx) {
555 this.set_active_editor_index(Some(ix), cx);
556 } else {
557 let editor = cx
558 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
559 this.add_conversation(editor, cx);
560 }
561 })?;
562 Ok(())
563 })
564 }
565
566 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
567 self.editors
568 .iter()
569 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
570 }
571}
572
573fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
574 cx.add_view(|cx| {
575 let mut editor = Editor::single_line(
576 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
577 cx,
578 );
579 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
580 editor
581 })
582}
583
/// `AssistantPanel` emits [`AssistantPanelEvent`]s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
587
588impl View for AssistantPanel {
589 fn ui_name() -> &'static str {
590 "AssistantPanel"
591 }
592
593 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
594 let theme = &theme::current(cx);
595 let style = &theme.assistant;
596 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
597 Flex::column()
598 .with_child(
599 Text::new(
600 "Paste your OpenAI API key and press Enter to use the assistant",
601 style.api_key_prompt.text.clone(),
602 )
603 .aligned(),
604 )
605 .with_child(
606 ChildView::new(api_key_editor, cx)
607 .contained()
608 .with_style(style.api_key_editor.container)
609 .aligned(),
610 )
611 .contained()
612 .with_style(style.api_key_prompt.container)
613 .aligned()
614 .into_any()
615 } else {
616 let title = self.active_editor().map(|editor| {
617 Label::new(editor.read(cx).title(cx), style.title.text.clone())
618 .contained()
619 .with_style(style.title.container)
620 .aligned()
621 .left()
622 .flex(1., false)
623 });
624 let mut header = Flex::row()
625 .with_child(Self::render_hamburger_button(cx).aligned())
626 .with_children(title);
627 if self.has_focus {
628 header.add_children(
629 self.render_editor_tools(cx)
630 .into_iter()
631 .map(|tool| tool.aligned().flex_float()),
632 );
633 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
634 header.add_child(self.render_zoom_button(cx).aligned());
635 }
636
637 Flex::column()
638 .with_child(
639 header
640 .contained()
641 .with_style(theme.workspace.tab_bar.container)
642 .expanded()
643 .constrained()
644 .with_height(theme.workspace.tab_bar.height),
645 )
646 .with_children(if self.toolbar.read(cx).hidden() {
647 None
648 } else {
649 Some(ChildView::new(&self.toolbar, cx).expanded())
650 })
651 .with_child(if let Some(editor) = self.active_editor() {
652 ChildView::new(editor, cx).flex(1., true).into_any()
653 } else {
654 UniformList::new(
655 self.saved_conversations_list_state.clone(),
656 self.saved_conversations.len(),
657 cx,
658 |this, range, items, cx| {
659 for ix in range {
660 items.push(this.render_saved_conversation(ix, cx).into_any());
661 }
662 },
663 )
664 .flex(1., true)
665 .into_any()
666 })
667 .into_any()
668 }
669 }
670
671 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
672 self.has_focus = true;
673 self.toolbar
674 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
675 cx.notify();
676 if cx.is_self_focused() {
677 if let Some(editor) = self.active_editor() {
678 cx.focus(editor);
679 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
680 cx.focus(api_key_editor);
681 }
682 }
683 }
684
685 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
686 self.has_focus = false;
687 self.toolbar
688 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
689 cx.notify();
690 }
691}
692
693impl Panel for AssistantPanel {
694 fn position(&self, cx: &WindowContext) -> DockPosition {
695 match settings::get::<AssistantSettings>(cx).dock {
696 AssistantDockPosition::Left => DockPosition::Left,
697 AssistantDockPosition::Bottom => DockPosition::Bottom,
698 AssistantDockPosition::Right => DockPosition::Right,
699 }
700 }
701
702 fn position_is_valid(&self, _: DockPosition) -> bool {
703 true
704 }
705
706 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
707 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
708 let dock = match position {
709 DockPosition::Left => AssistantDockPosition::Left,
710 DockPosition::Bottom => AssistantDockPosition::Bottom,
711 DockPosition::Right => AssistantDockPosition::Right,
712 };
713 settings.dock = Some(dock);
714 });
715 }
716
717 fn size(&self, cx: &WindowContext) -> f32 {
718 let settings = settings::get::<AssistantSettings>(cx);
719 match self.position(cx) {
720 DockPosition::Left | DockPosition::Right => {
721 self.width.unwrap_or_else(|| settings.default_width)
722 }
723 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
724 }
725 }
726
727 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
728 match self.position(cx) {
729 DockPosition::Left | DockPosition::Right => self.width = Some(size),
730 DockPosition::Bottom => self.height = Some(size),
731 }
732 cx.notify();
733 }
734
735 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
736 matches!(event, AssistantPanelEvent::ZoomIn)
737 }
738
739 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
740 matches!(event, AssistantPanelEvent::ZoomOut)
741 }
742
743 fn is_zoomed(&self, _: &WindowContext) -> bool {
744 self.zoomed
745 }
746
747 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
748 self.zoomed = zoomed;
749 cx.notify();
750 }
751
752 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
753 if active {
754 if self.api_key.borrow().is_none() && !self.has_read_credentials {
755 self.has_read_credentials = true;
756 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
757 Some(api_key)
758 } else if let Some((_, api_key)) = cx
759 .platform()
760 .read_credentials(OPENAI_API_URL)
761 .log_err()
762 .flatten()
763 {
764 String::from_utf8(api_key).log_err()
765 } else {
766 None
767 };
768 if let Some(api_key) = api_key {
769 *self.api_key.borrow_mut() = Some(api_key);
770 } else if self.api_key_editor.is_none() {
771 self.api_key_editor = Some(build_api_key_editor(cx));
772 cx.notify();
773 }
774 }
775
776 if self.editors.is_empty() {
777 self.new_conversation(cx);
778 }
779 }
780 }
781
782 fn icon_path(&self) -> &'static str {
783 "icons/robot_14.svg"
784 }
785
786 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
787 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
788 }
789
790 fn should_change_position_on_event(event: &Self::Event) -> bool {
791 matches!(event, AssistantPanelEvent::DockPositionChanged)
792 }
793
794 fn should_activate_on_event(_: &Self::Event) -> bool {
795 false
796 }
797
798 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
799 matches!(event, AssistantPanelEvent::Close)
800 }
801
802 fn has_focus(&self, _: &WindowContext) -> bool {
803 self.has_focus
804 }
805
806 fn is_focus_event(event: &Self::Event) -> bool {
807 matches!(event, AssistantPanelEvent::Focus)
808 }
809}
810
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// Emitted after the underlying buffer reports an `Edited` event.
    MessagesEdited,
    SummaryChanged,
    StreamedCompletion,
}
816
/// Summary text of a conversation.
#[derive(Default)]
struct Summary {
    /// The summary text; this is what gets serialized with the conversation.
    text: String,
    /// Set to `true` for summaries restored from a saved conversation.
    /// NOTE(review): presumably `false` while a summary is still being
    /// generated — confirm against the code that produces summaries.
    done: bool,
}
822
/// Model holding the state of one assistant conversation: its text buffer,
/// the messages laid out over that buffer, and token accounting.
struct Conversation {
    /// Buffer containing the text of all messages.
    buffer: ModelHandle<Buffer>,
    /// Start anchor and id of each message within `buffer`.
    message_anchors: Vec<MessageAnchor>,
    /// Per-message metadata (role, send time, status), keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id to assign to the next message that gets created.
    next_message_id: MessageId,
    summary: Option<Summary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    /// Name of the OpenAI model used for requests and token counting.
    model: String,
    /// Most recently computed token count; `None` until a count has finished.
    token_count: Option<usize>,
    /// Context size of `model`, as reported by `tiktoken_rs`.
    max_token_count: usize,
    /// The latest token-counting task started by `count_remaining_tokens`.
    pending_token_count: Task<Option<()>>,
    /// API key slot shared with the panel that created this conversation.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    /// File this conversation was loaded from; `None` for new conversations.
    path: Option<PathBuf>,
    _subscriptions: Vec<Subscription>,
}
841
/// `Conversation` emits [`ConversationEvent`]s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
845
846impl Conversation {
847 fn new(
848 api_key: Rc<RefCell<Option<String>>>,
849 language_registry: Arc<LanguageRegistry>,
850 cx: &mut ModelContext<Self>,
851 ) -> Self {
852 let model = "gpt-3.5-turbo-0613";
853 let markdown = language_registry.language_for_name("Markdown");
854 let buffer = cx.add_model(|cx| {
855 let mut buffer = Buffer::new(0, "", cx);
856 buffer.set_language_registry(language_registry);
857 cx.spawn_weak(|buffer, mut cx| async move {
858 let markdown = markdown.await?;
859 let buffer = buffer
860 .upgrade(&cx)
861 .ok_or_else(|| anyhow!("buffer was dropped"))?;
862 buffer.update(&mut cx, |buffer, cx| {
863 buffer.set_language(Some(markdown), cx)
864 });
865 anyhow::Ok(())
866 })
867 .detach_and_log_err(cx);
868 buffer
869 });
870
871 let mut this = Self {
872 message_anchors: Default::default(),
873 messages_metadata: Default::default(),
874 next_message_id: Default::default(),
875 summary: None,
876 pending_summary: Task::ready(None),
877 completion_count: Default::default(),
878 pending_completions: Default::default(),
879 token_count: None,
880 max_token_count: tiktoken_rs::model::get_context_size(model),
881 pending_token_count: Task::ready(None),
882 model: model.into(),
883 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
884 pending_save: Task::ready(Ok(())),
885 path: None,
886 api_key,
887 buffer,
888 };
889 let message = MessageAnchor {
890 id: MessageId(post_inc(&mut this.next_message_id.0)),
891 start: language::Anchor::MIN,
892 };
893 this.message_anchors.push(message.clone());
894 this.messages_metadata.insert(
895 message.id,
896 MessageMetadata {
897 role: Role::User,
898 sent_at: Local::now(),
899 status: MessageStatus::Done,
900 },
901 );
902
903 this.count_remaining_tokens(cx);
904 this
905 }
906
907 fn serialize(&self, cx: &AppContext) -> SavedConversation {
908 SavedConversation {
909 zed: "conversation".into(),
910 version: SavedConversation::VERSION.into(),
911 text: self.buffer.read(cx).text(),
912 message_metadata: self.messages_metadata.clone(),
913 messages: self
914 .messages(cx)
915 .map(|message| SavedMessage {
916 id: message.id,
917 start: message.offset_range.start,
918 })
919 .collect(),
920 summary: self
921 .summary
922 .as_ref()
923 .map(|summary| summary.text.clone())
924 .unwrap_or_default(),
925 model: self.model.clone(),
926 }
927 }
928
929 fn deserialize(
930 saved_conversation: SavedConversation,
931 path: PathBuf,
932 api_key: Rc<RefCell<Option<String>>>,
933 language_registry: Arc<LanguageRegistry>,
934 cx: &mut ModelContext<Self>,
935 ) -> Self {
936 let model = saved_conversation.model;
937 let markdown = language_registry.language_for_name("Markdown");
938 let mut message_anchors = Vec::new();
939 let mut next_message_id = MessageId(0);
940 let buffer = cx.add_model(|cx| {
941 let mut buffer = Buffer::new(0, saved_conversation.text, cx);
942 for message in saved_conversation.messages {
943 message_anchors.push(MessageAnchor {
944 id: message.id,
945 start: buffer.anchor_before(message.start),
946 });
947 next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
948 }
949 buffer.set_language_registry(language_registry);
950 cx.spawn_weak(|buffer, mut cx| async move {
951 let markdown = markdown.await?;
952 let buffer = buffer
953 .upgrade(&cx)
954 .ok_or_else(|| anyhow!("buffer was dropped"))?;
955 buffer.update(&mut cx, |buffer, cx| {
956 buffer.set_language(Some(markdown), cx)
957 });
958 anyhow::Ok(())
959 })
960 .detach_and_log_err(cx);
961 buffer
962 });
963
964 let mut this = Self {
965 message_anchors,
966 messages_metadata: saved_conversation.message_metadata,
967 next_message_id,
968 summary: Some(Summary {
969 text: saved_conversation.summary,
970 done: true,
971 }),
972 pending_summary: Task::ready(None),
973 completion_count: Default::default(),
974 pending_completions: Default::default(),
975 token_count: None,
976 max_token_count: tiktoken_rs::model::get_context_size(&model),
977 pending_token_count: Task::ready(None),
978 model,
979 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
980 pending_save: Task::ready(Ok(())),
981 path: Some(path),
982 api_key,
983 buffer,
984 };
985 this.count_remaining_tokens(cx);
986 this
987 }
988
989 fn handle_buffer_event(
990 &mut self,
991 _: ModelHandle<Buffer>,
992 event: &language::Event,
993 cx: &mut ModelContext<Self>,
994 ) {
995 match event {
996 language::Event::Edited => {
997 self.count_remaining_tokens(cx);
998 cx.emit(ConversationEvent::MessagesEdited);
999 }
1000 _ => {}
1001 }
1002 }
1003
1004 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1005 let messages = self
1006 .messages(cx)
1007 .into_iter()
1008 .filter_map(|message| {
1009 Some(tiktoken_rs::ChatCompletionRequestMessage {
1010 role: match message.role {
1011 Role::User => "user".into(),
1012 Role::Assistant => "assistant".into(),
1013 Role::System => "system".into(),
1014 },
1015 content: self
1016 .buffer
1017 .read(cx)
1018 .text_for_range(message.offset_range)
1019 .collect(),
1020 name: None,
1021 })
1022 })
1023 .collect::<Vec<_>>();
1024 let model = self.model.clone();
1025 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1026 async move {
1027 cx.background().timer(Duration::from_millis(200)).await;
1028 let token_count = cx
1029 .background()
1030 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1031 .await?;
1032
1033 this.upgrade(&cx)
1034 .ok_or_else(|| anyhow!("conversation was dropped"))?
1035 .update(&mut cx, |this, cx| {
1036 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1037 this.token_count = Some(token_count);
1038 cx.notify()
1039 });
1040 anyhow::Ok(())
1041 }
1042 .log_err()
1043 });
1044 }
1045
1046 fn remaining_tokens(&self) -> Option<isize> {
1047 Some(self.max_token_count as isize - self.token_count? as isize)
1048 }
1049
    /// Switches the conversation to `model`, restarts token counting for it
    /// and notifies observers.
    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
        self.model = model;
        self.count_remaining_tokens(cx);
        cx.notify();
    }
1055
1056 fn assist(
1057 &mut self,
1058 selected_messages: HashSet<MessageId>,
1059 cx: &mut ModelContext<Self>,
1060 ) -> Vec<MessageAnchor> {
1061 let mut user_messages = Vec::new();
1062 let mut tasks = Vec::new();
1063
1064 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1065 message
1066 .start
1067 .is_valid(self.buffer.read(cx))
1068 .then_some(message.id)
1069 });
1070
1071 for selected_message_id in selected_messages {
1072 let selected_message_role =
1073 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1074 metadata.role
1075 } else {
1076 continue;
1077 };
1078
1079 if selected_message_role == Role::Assistant {
1080 if let Some(user_message) = self.insert_message_after(
1081 selected_message_id,
1082 Role::User,
1083 MessageStatus::Done,
1084 cx,
1085 ) {
1086 user_messages.push(user_message);
1087 } else {
1088 continue;
1089 }
1090 } else {
1091 let request = OpenAIRequest {
1092 model: self.model.clone(),
1093 messages: self
1094 .messages(cx)
1095 .filter(|message| matches!(message.status, MessageStatus::Done))
1096 .flat_map(|message| {
1097 let mut system_message = None;
1098 if message.id == selected_message_id {
1099 system_message = Some(RequestMessage {
1100 role: Role::System,
1101 content: concat!(
1102 "Treat the following messages as additional knowledge you have learned about, ",
1103 "but act as if they were not part of this conversation. That is, treat them ",
1104 "as if the user didn't see them and couldn't possibly inquire about them."
1105 ).into()
1106 });
1107 }
1108
1109 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1110 })
1111 .chain(Some(RequestMessage {
1112 role: Role::System,
1113 content: format!(
1114 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1115 selected_message_id.0
1116 ),
1117 }))
1118 .collect(),
1119 stream: true,
1120 };
1121
1122 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1123 let stream = stream_completion(api_key, cx.background().clone(), request);
1124 let assistant_message = self
1125 .insert_message_after(
1126 selected_message_id,
1127 Role::Assistant,
1128 MessageStatus::Pending,
1129 cx,
1130 )
1131 .unwrap();
1132
1133 // Queue up the user's next reply
1134 if Some(selected_message_id) == last_message_id {
1135 let user_message = self
1136 .insert_message_after(
1137 assistant_message.id,
1138 Role::User,
1139 MessageStatus::Done,
1140 cx,
1141 )
1142 .unwrap();
1143 user_messages.push(user_message);
1144 }
1145
1146 tasks.push(cx.spawn_weak({
1147 |this, mut cx| async move {
1148 let assistant_message_id = assistant_message.id;
1149 let stream_completion = async {
1150 let mut messages = stream.await?;
1151
1152 while let Some(message) = messages.next().await {
1153 let mut message = message?;
1154 if let Some(choice) = message.choices.pop() {
1155 this.upgrade(&cx)
1156 .ok_or_else(|| anyhow!("conversation was dropped"))?
1157 .update(&mut cx, |this, cx| {
1158 let text: Arc<str> = choice.delta.content?.into();
1159 let message_ix = this.message_anchors.iter().position(
1160 |message| message.id == assistant_message_id,
1161 )?;
1162 this.buffer.update(cx, |buffer, cx| {
1163 let offset = this.message_anchors[message_ix + 1..]
1164 .iter()
1165 .find(|message| message.start.is_valid(buffer))
1166 .map_or(buffer.len(), |message| {
1167 message
1168 .start
1169 .to_offset(buffer)
1170 .saturating_sub(1)
1171 });
1172 buffer.edit([(offset..offset, text)], None, cx);
1173 });
1174 cx.emit(ConversationEvent::StreamedCompletion);
1175
1176 Some(())
1177 });
1178 }
1179 smol::future::yield_now().await;
1180 }
1181
1182 this.upgrade(&cx)
1183 .ok_or_else(|| anyhow!("conversation was dropped"))?
1184 .update(&mut cx, |this, cx| {
1185 this.pending_completions.retain(|completion| {
1186 completion.id != this.completion_count
1187 });
1188 this.summarize(cx);
1189 });
1190
1191 anyhow::Ok(())
1192 };
1193
1194 let result = stream_completion.await;
1195 if let Some(this) = this.upgrade(&cx) {
1196 this.update(&mut cx, |this, cx| {
1197 if let Some(metadata) =
1198 this.messages_metadata.get_mut(&assistant_message.id)
1199 {
1200 match result {
1201 Ok(_) => {
1202 metadata.status = MessageStatus::Done;
1203 }
1204 Err(error) => {
1205 metadata.status = MessageStatus::Error(
1206 error.to_string().trim().into(),
1207 );
1208 }
1209 }
1210 cx.notify();
1211 }
1212 });
1213 }
1214 }
1215 }));
1216 }
1217 }
1218
1219 if !tasks.is_empty() {
1220 self.pending_completions.push(PendingCompletion {
1221 id: post_inc(&mut self.completion_count),
1222 _tasks: tasks,
1223 });
1224 }
1225
1226 user_messages
1227 }
1228
    /// Removes the most recently added pending completion and returns `true` if there was
    /// one. NOTE(review): dropping the removed tasks presumably cancels them — confirm.
    fn cancel_last_assist(&mut self) -> bool {
        self.pending_completions.pop().is_some()
    }
1232
1233 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1234 for id in ids {
1235 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1236 metadata.role.cycle();
1237 cx.emit(ConversationEvent::MessagesEdited);
1238 cx.notify();
1239 }
1240 }
1241 }
1242
1243 fn insert_message_after(
1244 &mut self,
1245 message_id: MessageId,
1246 role: Role,
1247 status: MessageStatus,
1248 cx: &mut ModelContext<Self>,
1249 ) -> Option<MessageAnchor> {
1250 if let Some(prev_message_ix) = self
1251 .message_anchors
1252 .iter()
1253 .position(|message| message.id == message_id)
1254 {
1255 // Find the next valid message after the one we were given.
1256 let mut next_message_ix = prev_message_ix + 1;
1257 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1258 if next_message.start.is_valid(self.buffer.read(cx)) {
1259 break;
1260 }
1261 next_message_ix += 1;
1262 }
1263
1264 let start = self.buffer.update(cx, |buffer, cx| {
1265 let offset = self
1266 .message_anchors
1267 .get(next_message_ix)
1268 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1269 buffer.edit([(offset..offset, "\n")], None, cx);
1270 buffer.anchor_before(offset + 1)
1271 });
1272 let message = MessageAnchor {
1273 id: MessageId(post_inc(&mut self.next_message_id.0)),
1274 start,
1275 };
1276 self.message_anchors
1277 .insert(next_message_ix, message.clone());
1278 self.messages_metadata.insert(
1279 message.id,
1280 MessageMetadata {
1281 role,
1282 sent_at: Local::now(),
1283 status,
1284 },
1285 );
1286 cx.emit(ConversationEvent::MessagesEdited);
1287 Some(message)
1288 } else {
1289 None
1290 }
1291 }
1292
1293 fn split_message(
1294 &mut self,
1295 range: Range<usize>,
1296 cx: &mut ModelContext<Self>,
1297 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1298 let start_message = self.message_for_offset(range.start, cx);
1299 let end_message = self.message_for_offset(range.end, cx);
1300 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1301 // Prevent splitting when range spans multiple messages.
1302 if start_message.id != end_message.id {
1303 return (None, None);
1304 }
1305
1306 let message = start_message;
1307 let role = message.role;
1308 let mut edited_buffer = false;
1309
1310 let mut suffix_start = None;
1311 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1312 {
1313 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1314 suffix_start = Some(range.end + 1);
1315 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1316 suffix_start = Some(range.end);
1317 }
1318 }
1319
1320 let suffix = if let Some(suffix_start) = suffix_start {
1321 MessageAnchor {
1322 id: MessageId(post_inc(&mut self.next_message_id.0)),
1323 start: self.buffer.read(cx).anchor_before(suffix_start),
1324 }
1325 } else {
1326 self.buffer.update(cx, |buffer, cx| {
1327 buffer.edit([(range.end..range.end, "\n")], None, cx);
1328 });
1329 edited_buffer = true;
1330 MessageAnchor {
1331 id: MessageId(post_inc(&mut self.next_message_id.0)),
1332 start: self.buffer.read(cx).anchor_before(range.end + 1),
1333 }
1334 };
1335
1336 self.message_anchors
1337 .insert(message.index_range.end + 1, suffix.clone());
1338 self.messages_metadata.insert(
1339 suffix.id,
1340 MessageMetadata {
1341 role,
1342 sent_at: Local::now(),
1343 status: MessageStatus::Done,
1344 },
1345 );
1346
1347 let new_messages =
1348 if range.start == range.end || range.start == message.offset_range.start {
1349 (None, Some(suffix))
1350 } else {
1351 let mut prefix_end = None;
1352 if range.start > message.offset_range.start
1353 && range.end < message.offset_range.end - 1
1354 {
1355 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1356 prefix_end = Some(range.start + 1);
1357 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1358 == Some('\n')
1359 {
1360 prefix_end = Some(range.start);
1361 }
1362 }
1363
1364 let selection = if let Some(prefix_end) = prefix_end {
1365 cx.emit(ConversationEvent::MessagesEdited);
1366 MessageAnchor {
1367 id: MessageId(post_inc(&mut self.next_message_id.0)),
1368 start: self.buffer.read(cx).anchor_before(prefix_end),
1369 }
1370 } else {
1371 self.buffer.update(cx, |buffer, cx| {
1372 buffer.edit([(range.start..range.start, "\n")], None, cx)
1373 });
1374 edited_buffer = true;
1375 MessageAnchor {
1376 id: MessageId(post_inc(&mut self.next_message_id.0)),
1377 start: self.buffer.read(cx).anchor_before(range.end + 1),
1378 }
1379 };
1380
1381 self.message_anchors
1382 .insert(message.index_range.end + 1, selection.clone());
1383 self.messages_metadata.insert(
1384 selection.id,
1385 MessageMetadata {
1386 role,
1387 sent_at: Local::now(),
1388 status: MessageStatus::Done,
1389 },
1390 );
1391 (Some(selection), Some(suffix))
1392 };
1393
1394 if !edited_buffer {
1395 cx.emit(ConversationEvent::MessagesEdited);
1396 }
1397 new_messages
1398 } else {
1399 (None, None)
1400 }
1401 }
1402
1403 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1404 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1405 let api_key = self.api_key.borrow().clone();
1406 if let Some(api_key) = api_key {
1407 let messages = self
1408 .messages(cx)
1409 .take(2)
1410 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1411 .chain(Some(RequestMessage {
1412 role: Role::User,
1413 content:
1414 "Summarize the conversation into a short title without punctuation"
1415 .into(),
1416 }));
1417 let request = OpenAIRequest {
1418 model: self.model.clone(),
1419 messages: messages.collect(),
1420 stream: true,
1421 };
1422
1423 let stream = stream_completion(api_key, cx.background().clone(), request);
1424 self.pending_summary = cx.spawn(|this, mut cx| {
1425 async move {
1426 let mut messages = stream.await?;
1427
1428 while let Some(message) = messages.next().await {
1429 let mut message = message?;
1430 if let Some(choice) = message.choices.pop() {
1431 let text = choice.delta.content.unwrap_or_default();
1432 this.update(&mut cx, |this, cx| {
1433 this.summary
1434 .get_or_insert(Default::default())
1435 .text
1436 .push_str(&text);
1437 cx.emit(ConversationEvent::SummaryChanged);
1438 });
1439 }
1440 }
1441
1442 this.update(&mut cx, |this, cx| {
1443 if let Some(summary) = this.summary.as_mut() {
1444 summary.done = true;
1445 cx.emit(ConversationEvent::SummaryChanged);
1446 }
1447 });
1448
1449 anyhow::Ok(())
1450 }
1451 .log_err()
1452 });
1453 }
1454 }
1455 }
1456
    /// Returns the message found for `offset` by `messages_for_offsets`, if any.
    fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
        self.messages_for_offsets([offset], cx).pop()
    }
1460
    /// Returns the messages containing the given buffer offsets, one entry per distinct run
    /// of offsets that fall into the same message.
    ///
    /// Messages are only scanned forwards, so this presumably expects `offsets` in ascending
    /// order — verify against callers. An offset that lies in no message's range is
    /// attributed to the last message.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // The scan stops at the last message even if it does not contain the offset.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            // There are no messages at all.
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message.
            // Once the last message is reached, every remaining offset is consumed here.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1491
1492 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1493 let buffer = self.buffer.read(cx);
1494 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1495 iter::from_fn(move || {
1496 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1497 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1498 let message_start = message_anchor.start.to_offset(buffer);
1499 let mut message_end = None;
1500 let mut end_ix = start_ix;
1501 while let Some((_, next_message)) = message_anchors.peek() {
1502 if next_message.start.is_valid(buffer) {
1503 message_end = Some(next_message.start);
1504 break;
1505 } else {
1506 end_ix += 1;
1507 message_anchors.next();
1508 }
1509 }
1510 let message_end = message_end
1511 .unwrap_or(language::Anchor::MAX)
1512 .to_offset(buffer);
1513 return Some(Message {
1514 index_range: start_ix..end_ix,
1515 offset_range: message_start..message_end,
1516 id: message_anchor.id,
1517 anchor: message_anchor.start,
1518 role: metadata.role,
1519 sent_at: metadata.sent_at,
1520 status: metadata.status.clone(),
1521 });
1522 }
1523 None
1524 })
1525 }
1526
1527 fn save(
1528 &mut self,
1529 debounce: Option<Duration>,
1530 fs: Arc<dyn Fs>,
1531 cx: &mut ModelContext<Conversation>,
1532 ) {
1533 self.pending_save = cx.spawn(|this, mut cx| async move {
1534 if let Some(debounce) = debounce {
1535 cx.background().timer(debounce).await;
1536 }
1537
1538 let (old_path, summary) = this.read_with(&cx, |this, _| {
1539 let path = this.path.clone();
1540 let summary = if let Some(summary) = this.summary.as_ref() {
1541 if summary.done {
1542 Some(summary.text.clone())
1543 } else {
1544 None
1545 }
1546 } else {
1547 None
1548 };
1549 (path, summary)
1550 });
1551
1552 if let Some(summary) = summary {
1553 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1554 let path = if let Some(old_path) = old_path {
1555 old_path
1556 } else {
1557 let mut discriminant = 1;
1558 let mut new_path;
1559 loop {
1560 new_path = CONVERSATIONS_DIR.join(&format!(
1561 "{} - {}.zed.json",
1562 summary.trim(),
1563 discriminant
1564 ));
1565 if fs.is_file(&new_path).await {
1566 discriminant += 1;
1567 } else {
1568 break;
1569 }
1570 }
1571 new_path
1572 };
1573
1574 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1575 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1576 .await?;
1577 this.update(&mut cx, |this, _| this.path = Some(path));
1578 }
1579
1580 Ok(())
1581 });
1582 }
1583}
1584
/// A group of in-flight completion tasks started by one assist request.
struct PendingCompletion {
    /// Identifier assigned from the conversation's completion counter.
    id: usize,
    // Never read; presumably held only so the tasks live as long as this entry — confirm.
    _tasks: Vec<Task<()>>,
}
1589
/// Events emitted by a `ConversationEditor`.
enum ConversationEditorEvent {
    /// Emitted when the conversation's summary changes.
    TabContentChanged,
}
1593
/// The cursor together with its position relative to the editor's scroll position,
/// recorded so the same relative position can be restored while text is streamed in.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// Horizontal scroll position (x) and the number of rows between the top of the
    /// viewport and the cursor (y).
    offset_before_cursor: Vector2F,
    /// The cursor the offset is measured from.
    cursor: Anchor,
}
1599
/// A view that presents a `Conversation` in a text editor, with a header block above
/// every message.
struct ConversationEditor {
    /// The conversation being edited.
    conversation: ModelHandle<Conversation>,
    /// File system handle passed to the conversation when saving.
    fs: Arc<dyn Fs>,
    /// The editor showing the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message header blocks currently inserted into the editor.
    blocks: HashSet<BlockId>,
    /// Cursor-relative scroll position to restore on streamed completions, if any.
    scroll_position: Option<ScrollPosition>,
    // Never read; presumably held to keep the subscriptions active — confirm.
    _subscriptions: Vec<Subscription>,
}
1608
1609impl ConversationEditor {
    /// Creates an editor for a brand-new conversation built from `api_key` and
    /// `language_registry`.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::for_conversation(conversation, fs, cx)
    }
1619
    /// Creates an editor for an existing `conversation`.
    ///
    /// The inner editor is opened on the conversation's buffer with soft wrapping at the
    /// editor width and the gutter hidden. The view re-renders whenever the conversation
    /// changes and subscribes to conversation and editor events; the message headers are
    /// built once before returning.
    fn for_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }
1649
1650 fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
1651 let cursors = self.cursors(cx);
1652
1653 let user_messages = self.conversation.update(cx, |conversation, cx| {
1654 let selected_messages = conversation
1655 .messages_for_offsets(cursors, cx)
1656 .into_iter()
1657 .map(|message| message.id)
1658 .collect();
1659 conversation.assist(selected_messages, cx)
1660 });
1661 let new_selections = user_messages
1662 .iter()
1663 .map(|message| {
1664 let cursor = message
1665 .start
1666 .to_offset(self.conversation.read(cx).buffer.read(cx));
1667 cursor..cursor
1668 })
1669 .collect::<Vec<_>>();
1670 if !new_selections.is_empty() {
1671 self.editor.update(cx, |editor, cx| {
1672 editor.change_selections(
1673 Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
1674 cx,
1675 |selections| selections.select_ranges(new_selections),
1676 );
1677 });
1678 // Avoid scrolling to the new cursor position so the assistant's output is stable.
1679 cx.defer(|this, _| this.scroll_position = None);
1680 }
1681 }
1682
1683 fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
1684 if !self
1685 .conversation
1686 .update(cx, |conversation, _| conversation.cancel_last_assist())
1687 {
1688 cx.propagate_action();
1689 }
1690 }
1691
    /// Handles the `CycleMessageRole` action: cycles the role of every message that
    /// contains a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }
1703
1704 fn cursors(&self, cx: &AppContext) -> Vec<usize> {
1705 let selections = self.editor.read(cx).selections.all::<usize>(cx);
1706 selections
1707 .into_iter()
1708 .map(|selection| selection.head())
1709 .collect()
1710 }
1711
    /// Reacts to events emitted by the conversation.
    ///
    /// * `MessagesEdited`: rebuilds the message headers and saves with a 500ms debounce.
    /// * `SummaryChanged`: emits `TabContentChanged` and saves without debounce.
    /// * `StreamedCompletion`: if a scroll position is recorded, scrolls so that the
    ///   recorded cursor keeps its offset from the top of the viewport.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        // Keep the cursor the same number of rows below the viewport top.
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }
1747
    /// Keeps `scroll_position` in sync with the inner editor.
    ///
    /// An autoscroll or a selection change records the current cursor-relative scroll
    /// position; any other scroll that changes it clears the recorded position.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }
1769
1770 fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
1771 self.editor.update(cx, |editor, cx| {
1772 let snapshot = editor.snapshot(cx);
1773 let cursor = editor.selections.newest_anchor().head();
1774 let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
1775 let scroll_position = editor
1776 .scroll_manager
1777 .anchor()
1778 .scroll_position(&snapshot.display_snapshot);
1779
1780 let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
1781 if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
1782 Some(ScrollPosition {
1783 cursor,
1784 offset_before_cursor: vec2f(
1785 scroll_position.x(),
1786 cursor_row - scroll_position.y(),
1787 ),
1788 })
1789 } else {
1790 None
1791 }
1792 })
1793 }
1794
    /// Rebuilds the header blocks shown above every message.
    ///
    /// All previously inserted blocks are removed and one sticky, two-row block is inserted
    /// per message. Each header shows the sender (clicking it cycles the message's role),
    /// the time the message was sent and, for messages in an error state, an error icon
    /// with the error text as its tooltip.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        move |cx| {
                            // Marker types that distinguish the two element kinds below.
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }
1905
    /// Handles the `QuoteSelection` action: inserts the text selected in the active editor
    /// into the assistant panel's conversation.
    ///
    /// Markdown selections are quoted line by line with `> `; everything else is wrapped in
    /// a fenced code block tagged with the lowercased language name (empty when the
    /// selection's start and end are in different languages or have none). The panel is
    /// focused even when the selection is empty; nothing is inserted in that case.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.downcast::<Editor>()) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                // Use the active conversation editor, creating a conversation if needed.
                let conversation = panel
                    .active_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }
1965
1966 fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
1967 let editor = self.editor.read(cx);
1968 let conversation = self.conversation.read(cx);
1969 if editor.selections.count() == 1 {
1970 let selection = editor.selections.newest::<usize>(cx);
1971 let mut copied_text = String::new();
1972 let mut spanned_messages = 0;
1973 for message in conversation.messages(cx) {
1974 if message.offset_range.start >= selection.range().end {
1975 break;
1976 } else if message.offset_range.end >= selection.range().start {
1977 let range = cmp::max(message.offset_range.start, selection.range().start)
1978 ..cmp::min(message.offset_range.end, selection.range().end);
1979 if !range.is_empty() {
1980 spanned_messages += 1;
1981 write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
1982 for chunk in conversation.buffer.read(cx).text_for_range(range) {
1983 copied_text.push_str(&chunk);
1984 }
1985 copied_text.push('\n');
1986 }
1987 }
1988 }
1989
1990 if spanned_messages > 1 {
1991 cx.platform()
1992 .write_to_clipboard(ClipboardItem::new(copied_text));
1993 return;
1994 }
1995 }
1996
1997 cx.propagate_action();
1998 }
1999
    /// Handles the `Split` action: splits the conversation's messages at every selection.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                // Take a fresh snapshot for each selection so that its anchors resolve
                // against the buffer as left by the previous split.
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }
2012
    /// Handles the `Save` action: saves the conversation without a debounce delay.
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }
2018
2019 fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
2020 self.conversation.update(cx, |conversation, cx| {
2021 let new_model = match conversation.model.as_str() {
2022 "gpt-4-0613" => "gpt-3.5-turbo-0613",
2023 _ => "gpt-4-0613",
2024 };
2025 conversation.set_model(new_model.into(), cx);
2026 });
2027 }
2028
2029 fn title(&self, cx: &AppContext) -> String {
2030 self.conversation
2031 .read(cx)
2032 .summary
2033 .as_ref()
2034 .map(|summary| summary.text.clone())
2035 .unwrap_or_else(|| "New Conversation".into())
2036 }
2037
    /// Renders a label with the conversation's current model; clicking it cycles the model.
    fn render_current_model(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> impl Element<Self> {
        // Marker type for this mouse event handler.
        enum Model {}

        MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
            let style = style.model.style_for(state);
            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
                .contained()
                .with_style(style.container)
        })
        .with_cursor_style(CursorStyle::PointingHand)
        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
    }
2054
2055 fn render_remaining_tokens(
2056 &self,
2057 style: &AssistantStyle,
2058 cx: &mut ViewContext<Self>,
2059 ) -> Option<impl Element<Self>> {
2060 let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
2061 let remaining_tokens_style = if remaining_tokens <= 0 {
2062 &style.no_remaining_tokens
2063 } else {
2064 &style.remaining_tokens
2065 };
2066 Some(
2067 Label::new(
2068 remaining_tokens.to_string(),
2069 remaining_tokens_style.text.clone(),
2070 )
2071 .contained()
2072 .with_style(remaining_tokens_style.container),
2073 )
2074 }
2075}
2076
/// `ConversationEditor` emits `ConversationEditorEvent`s.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2080
impl View for ConversationEditor {
    fn ui_name() -> &'static str {
        "ConversationEditor"
    }

    /// Renders the inner editor with the current model and the remaining token count
    /// stacked on top of it, aligned to the top right.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
        let theme = &theme::current(cx).assistant;
        Stack::new()
            .with_child(
                ChildView::new(&self.editor, cx)
                    .contained()
                    .with_style(theme.container),
            )
            .with_child(
                Flex::row()
                    .with_child(self.render_current_model(theme, cx))
                    .with_children(self.render_remaining_tokens(theme, cx))
                    .aligned()
                    .top()
                    .right(),
            )
            .into_any()
    }

    /// Forwards focus to the inner editor when this view itself receives focus.
    fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
        if cx.is_self_focused() {
            cx.focus(&self.editor);
        }
    }
}
2111
/// Marks where a message begins in the conversation's buffer.
#[derive(Clone, Debug)]
struct MessageAnchor {
    /// Identifier of the message.
    id: MessageId,
    /// Buffer anchor at the start of the message's text.
    start: language::Anchor,
}
2117
/// A resolved view of one message: its position in the buffer combined with its metadata.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets covered by the message, up to the start of the next valid message.
    offset_range: Range<usize>,
    /// Indices into the conversation's anchor list: from this message's anchor to the last
    /// anchor folded into it.
    index_range: Range<usize>,
    /// Identifier of the message.
    id: MessageId,
    /// Anchor at the start of the message.
    anchor: language::Anchor,
    /// Who the message is attributed to.
    role: Role,
    /// When the message was created.
    sent_at: DateTime<Local>,
    /// Status of the message (e.g. pending, done, or failed with an error).
    status: MessageStatus,
}
2128
2129impl Message {
2130 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2131 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2132 content.extend(buffer.text_for_range(self.offset_range.clone()));
2133 RequestMessage {
2134 role: self.role,
2135 content: content.trim_end().into(),
2136 }
2137 }
2138}
2139
2140async fn stream_completion(
2141 api_key: String,
2142 executor: Arc<Background>,
2143 mut request: OpenAIRequest,
2144) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2145 request.stream = true;
2146
2147 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2148
2149 let json_data = serde_json::to_string(&request)?;
2150 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2151 .header("Content-Type", "application/json")
2152 .header("Authorization", format!("Bearer {}", api_key))
2153 .body(json_data)?
2154 .send_async()
2155 .await?;
2156
2157 let status = response.status();
2158 if status == StatusCode::OK {
2159 executor
2160 .spawn(async move {
2161 let mut lines = BufReader::new(response.body_mut()).lines();
2162
2163 fn parse_line(
2164 line: Result<String, io::Error>,
2165 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2166 if let Some(data) = line?.strip_prefix("data: ") {
2167 let event = serde_json::from_str(&data)?;
2168 Ok(Some(event))
2169 } else {
2170 Ok(None)
2171 }
2172 }
2173
2174 while let Some(line) = lines.next().await {
2175 if let Some(event) = parse_line(line).transpose() {
2176 let done = event.as_ref().map_or(false, |event| {
2177 event
2178 .choices
2179 .last()
2180 .map_or(false, |choice| choice.finish_reason.is_some())
2181 });
2182 if tx.unbounded_send(event).is_err() {
2183 break;
2184 }
2185
2186 if done {
2187 break;
2188 }
2189 }
2190 }
2191
2192 anyhow::Ok(())
2193 })
2194 .detach();
2195
2196 Ok(rx)
2197 } else {
2198 let mut body = String::new();
2199 response.body_mut().read_to_string(&mut body).await?;
2200
2201 #[derive(Deserialize)]
2202 struct OpenAIResponse {
2203 error: OpenAIError,
2204 }
2205
2206 #[derive(Deserialize)]
2207 struct OpenAIError {
2208 message: String,
2209 }
2210
2211 match serde_json::from_str::<OpenAIResponse>(&body) {
2212 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2213 "Failed to connect to OpenAI API: {}",
2214 response.error.message,
2215 )),
2216
2217 _ => Err(anyhow!(
2218 "Failed to connect to OpenAI API: {} {}",
2219 response.status(),
2220 body,
2221 )),
2222 }
2223 }
2224}
2225
// Unit tests for `Conversation`: message insertion/merging, splitting, offset
// lookup, and serialization round-trips.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    // Checks how message ranges evolve as messages are inserted, the buffer is
    // edited, and text spanning message boundaries is deleted, undone and redone.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A fresh conversation contains a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message 1 grows message 1 by one character and
        // appends a new, empty message.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Insert one character at the start of each message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        // Appending after the last message.
        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message 2 again places the new message between
        // message 2 and message 3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        // Insert one character into message 4 and one into message 3.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    // Checks the buffer text and message ranges produced by `split_message`
    // for empty and non-empty ranges.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // Split with an empty range; only the second element of the returned
        // pair is used here.
        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        // Split at the same offset again, which is now the end of message 1.
        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        // Splitting at offset 9 leaves the buffer text unchanged.
        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        // Splitting at offset 9 a second time inserts a newline into the buffer.
        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting with a non-empty range returns two new messages.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    // Checks that `messages_for_offsets` maps buffer offsets to the messages
    // containing them.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build three messages holding "aaa", "bbb" and "ccc".
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        // One offset inside each message yields each message once.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Two offsets in the same message yield it only once, and the offset
        // at the end of the buffer maps to the last message.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // Appending an empty trailing message adds a newline to the buffer.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        // The start offset of each message resolves to that message, including
        // the empty trailing one.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Returns the ids of the messages reported by `messages_for_offsets`
        // for the given offsets.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    // Checks that serializing and then deserializing a conversation preserves
    // its buffer text and its messages.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            // Close the current transaction so the following `undo` does not
            // also revert this edit.
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Undo the insertion of the third message; the assertions below expect
        // only the first three messages to remain.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        // Round-trip through `serialize` / `deserialize` and expect the same
        // text and messages.
        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    // Collects `(id, role, offset_range)` for every message of `conversation`,
    // in the order `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}