1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI API (`/v1`). It is also the service name under which the
/// API key is written to, read from, and deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions in the `assistant` namespace, declared with gpui's `actions!` macro.
// `ToggleFocus` is handled on the workspace and `ResetKey` on `AssistantPanel`
// (both registered in `init`); `Split`, `Assist`, `QuoteSelection`,
// `NewConversation` and `ToggleFocus` are also attached to button tooltips.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |this: &mut AssistantPanel,
68 _: &workspace::NewFile,
69 cx: &mut ViewContext<AssistantPanel>| {
70 this.new_conversation(cx);
71 },
72 );
73 cx.add_action(ConversationEditor::assist);
74 cx.capture_action(ConversationEditor::cancel_last_assist);
75 cx.capture_action(ConversationEditor::save);
76 cx.add_action(ConversationEditor::quote_selection);
77 cx.capture_action(ConversationEditor::copy);
78 cx.add_action(ConversationEditor::split);
79 cx.capture_action(ConversationEditor::cycle_message_role);
80 cx.add_action(AssistantPanel::save_api_key);
81 cx.add_action(AssistantPanel::reset_api_key);
82 cx.add_action(AssistantPanel::toggle_zoom);
83 cx.add_action(AssistantPanel::deploy);
84 cx.add_action(AssistantPanel::select_next_match);
85 cx.add_action(AssistantPanel::select_prev_match);
86 cx.add_action(AssistantPanel::handle_editor_cancel);
87 cx.add_action(
88 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
89 workspace.toggle_panel_focus::<AssistantPanel>(cx);
90 },
91 );
92}
93
/// Events emitted by [`AssistantPanel`]. The `Panel` trait hooks
/// (`should_zoom_in_on_event`, `should_close_on_event`, ...) map them to dock behavior.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// Emitted by `toggle_zoom` while the panel is not zoomed.
    ZoomIn,
    /// Emitted by `toggle_zoom` while the panel is zoomed.
    ZoomOut,
    /// Matched by `Panel::is_focus_event`.
    /// NOTE(review): no emitter is visible in this section of the file.
    Focus,
    /// Matched by `Panel::should_close_on_event`.
    /// NOTE(review): no emitter is visible in this section of the file.
    Close,
    /// Emitted when a settings change moves the panel's configured dock position.
    DockPositionChanged,
}
102
/// Dockable panel hosting assistant conversations: a list of saved conversations
/// plus any number of open conversation editors, one of which may be active.
pub struct AssistantPanel {
    /// Owning workspace; used to dispatch `quote_selection` from the quote button.
    workspace: WeakViewHandle<Workspace>,
    /// User-resized width when docked left/right; `None` uses the settings default.
    width: Option<f32>,
    /// User-resized height when docked at the bottom; `None` uses the settings default.
    height: Option<f32>,
    /// Index into `editors` of the conversation being shown; `None` shows the
    /// saved-conversation list instead.
    active_editor_index: Option<usize>,
    /// Index that was active before the last change; restored by the history button.
    prev_active_editor_index: Option<usize>,
    /// All conversation editors opened in this panel.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of the conversations in `CONVERSATIONS_DIR`, refreshed on change.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// Scroll state of the saved-conversation `UniformList`.
    saved_conversations_list_state: UniformListState,
    /// Whether the dock currently shows the panel zoomed (set via `set_zoomed`).
    zoomed: bool,
    /// Whether the panel or one of its descendants has focus.
    has_focus: bool,
    /// Toolbar hosting a `BufferSearchBar` for the active conversation editor.
    toolbar: ViewHandle<Toolbar>,
    /// OpenAI API key shared (via `Rc<RefCell<_>>`) with every conversation so an
    /// update here is visible to all of them.
    api_key: Rc<RefCell<Option<String>>>,
    /// Single-line editor for entering a key; `Some` while the key prompt is shown.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once the env var / credential store lookup has been attempted.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Settings observer plus per-conversation subscriptions.
    subscriptions: Vec<Subscription>,
    /// Held so the saved-conversation directory watcher lives as long as the panel.
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
125 pub fn load(
126 workspace: WeakViewHandle<Workspace>,
127 cx: AsyncAppContext,
128 ) -> Task<Result<ViewHandle<Self>>> {
129 cx.spawn(|mut cx| async move {
130 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
131 let saved_conversations = SavedConversationMetadata::list(fs.clone())
132 .await
133 .log_err()
134 .unwrap_or_default();
135
136 // TODO: deserialize state.
137 let workspace_handle = workspace.clone();
138 workspace.update(&mut cx, |workspace, cx| {
139 cx.add_view::<Self, _>(|cx| {
140 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
141 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
142 let mut events = fs
143 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
144 .await;
145 while events.next().await.is_some() {
146 let saved_conversations = SavedConversationMetadata::list(fs.clone())
147 .await
148 .log_err()
149 .unwrap_or_default();
150 this.update(&mut cx, |this, cx| {
151 this.saved_conversations = saved_conversations;
152 cx.notify();
153 })
154 .ok();
155 }
156
157 anyhow::Ok(())
158 });
159
160 let toolbar = cx.add_view(|cx| {
161 let mut toolbar = Toolbar::new(None);
162 toolbar.set_can_navigate(false, cx);
163 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
164 toolbar
165 });
166 let mut this = Self {
167 workspace: workspace_handle,
168 active_editor_index: Default::default(),
169 prev_active_editor_index: Default::default(),
170 editors: Default::default(),
171 saved_conversations,
172 saved_conversations_list_state: Default::default(),
173 zoomed: false,
174 has_focus: false,
175 toolbar,
176 api_key: Rc::new(RefCell::new(None)),
177 api_key_editor: None,
178 has_read_credentials: false,
179 languages: workspace.app_state().languages.clone(),
180 fs: workspace.app_state().fs.clone(),
181 width: None,
182 height: None,
183 subscriptions: Default::default(),
184 _watch_saved_conversations,
185 };
186
187 let mut old_dock_position = this.position(cx);
188 this.subscriptions =
189 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
190 let new_dock_position = this.position(cx);
191 if new_dock_position != old_dock_position {
192 old_dock_position = new_dock_position;
193 cx.emit(AssistantPanelEvent::DockPositionChanged);
194 }
195 })];
196
197 this
198 })
199 })
200 })
201 }
202
203 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
204 let editor = cx.add_view(|cx| {
205 ConversationEditor::new(
206 self.api_key.clone(),
207 self.languages.clone(),
208 self.fs.clone(),
209 cx,
210 )
211 });
212 self.add_conversation(editor.clone(), cx);
213 editor
214 }
215
216 fn add_conversation(
217 &mut self,
218 editor: ViewHandle<ConversationEditor>,
219 cx: &mut ViewContext<Self>,
220 ) {
221 self.subscriptions
222 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
223
224 let conversation = editor.read(cx).conversation.clone();
225 self.subscriptions
226 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
227
228 let index = self.editors.len();
229 self.editors.push(editor);
230 self.set_active_editor_index(Some(index), cx);
231 }
232
233 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
234 self.prev_active_editor_index = self.active_editor_index;
235 self.active_editor_index = index;
236 if let Some(editor) = self.active_editor() {
237 let editor = editor.read(cx).editor.clone();
238 self.toolbar.update(cx, |toolbar, cx| {
239 toolbar.set_active_item(Some(&editor), cx);
240 });
241 if self.has_focus(cx) {
242 cx.focus(&editor);
243 }
244 } else {
245 self.toolbar.update(cx, |toolbar, cx| {
246 toolbar.set_active_item(None, cx);
247 });
248 }
249
250 cx.notify();
251 }
252
253 fn handle_conversation_editor_event(
254 &mut self,
255 _: ViewHandle<ConversationEditor>,
256 event: &ConversationEditorEvent,
257 cx: &mut ViewContext<Self>,
258 ) {
259 match event {
260 ConversationEditorEvent::TabContentChanged => cx.notify(),
261 }
262 }
263
264 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
265 if let Some(api_key) = self
266 .api_key_editor
267 .as_ref()
268 .map(|editor| editor.read(cx).text(cx))
269 {
270 if !api_key.is_empty() {
271 cx.platform()
272 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
273 .log_err();
274 *self.api_key.borrow_mut() = Some(api_key);
275 self.api_key_editor.take();
276 cx.focus_self();
277 cx.notify();
278 }
279 } else {
280 cx.propagate_action();
281 }
282 }
283
284 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
285 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
286 self.api_key.take();
287 self.api_key_editor = Some(build_api_key_editor(cx));
288 cx.focus_self();
289 cx.notify();
290 }
291
292 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
293 if self.zoomed {
294 cx.emit(AssistantPanelEvent::ZoomOut)
295 } else {
296 cx.emit(AssistantPanelEvent::ZoomIn)
297 }
298 }
299
300 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
301 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
302 if search_bar.update(cx, |search_bar, cx| search_bar.show(action.focus, true, cx)) {
303 return;
304 }
305 }
306 cx.propagate_action();
307 }
308
309 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
310 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
311 if !search_bar.read(cx).is_dismissed() {
312 search_bar.update(cx, |search_bar, cx| {
313 search_bar.dismiss(&Default::default(), cx)
314 });
315 return;
316 }
317 }
318 cx.propagate_action();
319 }
320
321 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
322 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
323 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, cx));
324 }
325 }
326
327 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
328 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
329 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, cx));
330 }
331 }
332
333 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
334 self.editors.get(self.active_editor_index?)
335 }
336
337 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
338 enum History {}
339 let theme = theme::current(cx);
340 let tooltip_style = theme::current(cx).tooltip.clone();
341 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
342 let style = theme.assistant.hamburger_button.style_for(state);
343 Svg::for_style(style.icon.clone())
344 .contained()
345 .with_style(style.container)
346 })
347 .with_cursor_style(CursorStyle::PointingHand)
348 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
349 if this.active_editor().is_some() {
350 this.set_active_editor_index(None, cx);
351 } else {
352 this.set_active_editor_index(this.prev_active_editor_index, cx);
353 }
354 })
355 .with_tooltip::<History>(1, "History".into(), None, tooltip_style, cx)
356 }
357
358 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
359 if self.active_editor().is_some() {
360 vec![
361 Self::render_split_button(cx).into_any(),
362 Self::render_quote_button(cx).into_any(),
363 Self::render_assist_button(cx).into_any(),
364 ]
365 } else {
366 Default::default()
367 }
368 }
369
370 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
371 let theme = theme::current(cx);
372 let tooltip_style = theme::current(cx).tooltip.clone();
373 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
374 let style = theme.assistant.split_button.style_for(state);
375 Svg::for_style(style.icon.clone())
376 .contained()
377 .with_style(style.container)
378 })
379 .with_cursor_style(CursorStyle::PointingHand)
380 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
381 if let Some(active_editor) = this.active_editor() {
382 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
383 }
384 })
385 .with_tooltip::<Split>(
386 1,
387 "Split Message".into(),
388 Some(Box::new(Split)),
389 tooltip_style,
390 cx,
391 )
392 }
393
394 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
395 let theme = theme::current(cx);
396 let tooltip_style = theme::current(cx).tooltip.clone();
397 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
398 let style = theme.assistant.assist_button.style_for(state);
399 Svg::for_style(style.icon.clone())
400 .contained()
401 .with_style(style.container)
402 })
403 .with_cursor_style(CursorStyle::PointingHand)
404 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
405 if let Some(active_editor) = this.active_editor() {
406 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
407 }
408 })
409 .with_tooltip::<Assist>(
410 1,
411 "Assist".into(),
412 Some(Box::new(Assist)),
413 tooltip_style,
414 cx,
415 )
416 }
417
418 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
419 let theme = theme::current(cx);
420 let tooltip_style = theme::current(cx).tooltip.clone();
421 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
422 let style = theme.assistant.quote_button.style_for(state);
423 Svg::for_style(style.icon.clone())
424 .contained()
425 .with_style(style.container)
426 })
427 .with_cursor_style(CursorStyle::PointingHand)
428 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
429 if let Some(workspace) = this.workspace.upgrade(cx) {
430 cx.window_context().defer(move |cx| {
431 workspace.update(cx, |workspace, cx| {
432 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
433 });
434 });
435 }
436 })
437 .with_tooltip::<QuoteSelection>(
438 1,
439 "Quote Selection".into(),
440 Some(Box::new(QuoteSelection)),
441 tooltip_style,
442 cx,
443 )
444 }
445
446 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
447 let theme = theme::current(cx);
448 let tooltip_style = theme::current(cx).tooltip.clone();
449 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
450 let style = theme.assistant.plus_button.style_for(state);
451 Svg::for_style(style.icon.clone())
452 .contained()
453 .with_style(style.container)
454 })
455 .with_cursor_style(CursorStyle::PointingHand)
456 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
457 this.new_conversation(cx);
458 })
459 .with_tooltip::<NewConversation>(
460 1,
461 "New Conversation".into(),
462 Some(Box::new(NewConversation)),
463 tooltip_style,
464 cx,
465 )
466 }
467
468 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
469 enum ToggleZoomButton {}
470
471 let theme = theme::current(cx);
472 let tooltip_style = theme::current(cx).tooltip.clone();
473 let style = if self.zoomed {
474 &theme.assistant.zoom_out_button
475 } else {
476 &theme.assistant.zoom_in_button
477 };
478
479 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
480 let style = style.style_for(state);
481 Svg::for_style(style.icon.clone())
482 .contained()
483 .with_style(style.container)
484 })
485 .with_cursor_style(CursorStyle::PointingHand)
486 .on_click(MouseButton::Left, |_, this, cx| {
487 this.toggle_zoom(&ToggleZoom, cx);
488 })
489 .with_tooltip::<ToggleZoom>(
490 0,
491 if self.zoomed {
492 "Zoom Out".into()
493 } else {
494 "Zoom In".into()
495 },
496 Some(Box::new(ToggleZoom)),
497 tooltip_style,
498 cx,
499 )
500 }
501
502 fn render_saved_conversation(
503 &mut self,
504 index: usize,
505 cx: &mut ViewContext<Self>,
506 ) -> impl Element<Self> {
507 let conversation = &self.saved_conversations[index];
508 let path = conversation.path.clone();
509 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
510 let style = &theme::current(cx).assistant.saved_conversation;
511 Flex::row()
512 .with_child(
513 Label::new(
514 conversation.mtime.format("%F %I:%M%p").to_string(),
515 style.saved_at.text.clone(),
516 )
517 .aligned()
518 .contained()
519 .with_style(style.saved_at.container),
520 )
521 .with_child(
522 Label::new(conversation.title.clone(), style.title.text.clone())
523 .aligned()
524 .contained()
525 .with_style(style.title.container),
526 )
527 .contained()
528 .with_style(*style.container.style_for(state))
529 })
530 .with_cursor_style(CursorStyle::PointingHand)
531 .on_click(MouseButton::Left, move |_, this, cx| {
532 this.open_conversation(path.clone(), cx)
533 .detach_and_log_err(cx)
534 })
535 }
536
537 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
538 if let Some(ix) = self.editor_index_for_path(&path, cx) {
539 self.set_active_editor_index(Some(ix), cx);
540 return Task::ready(Ok(()));
541 }
542
543 let fs = self.fs.clone();
544 let api_key = self.api_key.clone();
545 let languages = self.languages.clone();
546 cx.spawn(|this, mut cx| async move {
547 let saved_conversation = fs.load(&path).await?;
548 let saved_conversation = serde_json::from_str(&saved_conversation)?;
549 let conversation = cx.add_model(|cx| {
550 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
551 });
552 this.update(&mut cx, |this, cx| {
553 // If, by the time we've loaded the conversation, the user has already opened
554 // the same conversation, we don't want to open it again.
555 if let Some(ix) = this.editor_index_for_path(&path, cx) {
556 this.set_active_editor_index(Some(ix), cx);
557 } else {
558 let editor = cx
559 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
560 this.add_conversation(editor, cx);
561 }
562 })?;
563 Ok(())
564 })
565 }
566
567 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
568 self.editors
569 .iter()
570 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
571 }
572}
573
574fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
575 cx.add_view(|cx| {
576 let mut editor = Editor::single_line(
577 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
578 cx,
579 );
580 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
581 editor
582 })
583}
584
// `AssistantPanel` emits `AssistantPanelEvent`s via `cx.emit`.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
588
589impl View for AssistantPanel {
590 fn ui_name() -> &'static str {
591 "AssistantPanel"
592 }
593
594 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
595 let theme = &theme::current(cx);
596 let style = &theme.assistant;
597 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
598 Flex::column()
599 .with_child(
600 Text::new(
601 "Paste your OpenAI API key and press Enter to use the assistant",
602 style.api_key_prompt.text.clone(),
603 )
604 .aligned(),
605 )
606 .with_child(
607 ChildView::new(api_key_editor, cx)
608 .contained()
609 .with_style(style.api_key_editor.container)
610 .aligned(),
611 )
612 .contained()
613 .with_style(style.api_key_prompt.container)
614 .aligned()
615 .into_any()
616 } else {
617 let title = self.active_editor().map(|editor| {
618 Label::new(editor.read(cx).title(cx), style.title.text.clone())
619 .contained()
620 .with_style(style.title.container)
621 .aligned()
622 .left()
623 .flex(1., false)
624 });
625 let mut header = Flex::row()
626 .with_child(Self::render_hamburger_button(cx).aligned())
627 .with_children(title);
628 if self.has_focus {
629 header.add_children(
630 self.render_editor_tools(cx)
631 .into_iter()
632 .map(|tool| tool.aligned().flex_float()),
633 );
634 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
635 header.add_child(self.render_zoom_button(cx).aligned());
636 }
637
638 Flex::column()
639 .with_child(
640 header
641 .contained()
642 .with_style(theme.workspace.tab_bar.container)
643 .expanded()
644 .constrained()
645 .with_height(theme.workspace.tab_bar.height),
646 )
647 .with_children(if self.toolbar.read(cx).hidden() {
648 None
649 } else {
650 Some(ChildView::new(&self.toolbar, cx).expanded())
651 })
652 .with_child(if let Some(editor) = self.active_editor() {
653 ChildView::new(editor, cx).flex(1., true).into_any()
654 } else {
655 UniformList::new(
656 self.saved_conversations_list_state.clone(),
657 self.saved_conversations.len(),
658 cx,
659 |this, range, items, cx| {
660 for ix in range {
661 items.push(this.render_saved_conversation(ix, cx).into_any());
662 }
663 },
664 )
665 .flex(1., true)
666 .into_any()
667 })
668 .into_any()
669 }
670 }
671
672 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
673 self.has_focus = true;
674 self.toolbar
675 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
676 cx.notify();
677 if cx.is_self_focused() {
678 if let Some(editor) = self.active_editor() {
679 cx.focus(editor);
680 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
681 cx.focus(api_key_editor);
682 }
683 }
684 }
685
686 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
687 self.has_focus = false;
688 self.toolbar
689 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
690 cx.notify();
691 }
692}
693
694impl Panel for AssistantPanel {
695 fn position(&self, cx: &WindowContext) -> DockPosition {
696 match settings::get::<AssistantSettings>(cx).dock {
697 AssistantDockPosition::Left => DockPosition::Left,
698 AssistantDockPosition::Bottom => DockPosition::Bottom,
699 AssistantDockPosition::Right => DockPosition::Right,
700 }
701 }
702
703 fn position_is_valid(&self, _: DockPosition) -> bool {
704 true
705 }
706
707 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
708 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
709 let dock = match position {
710 DockPosition::Left => AssistantDockPosition::Left,
711 DockPosition::Bottom => AssistantDockPosition::Bottom,
712 DockPosition::Right => AssistantDockPosition::Right,
713 };
714 settings.dock = Some(dock);
715 });
716 }
717
718 fn size(&self, cx: &WindowContext) -> f32 {
719 let settings = settings::get::<AssistantSettings>(cx);
720 match self.position(cx) {
721 DockPosition::Left | DockPosition::Right => {
722 self.width.unwrap_or_else(|| settings.default_width)
723 }
724 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
725 }
726 }
727
728 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
729 match self.position(cx) {
730 DockPosition::Left | DockPosition::Right => self.width = Some(size),
731 DockPosition::Bottom => self.height = Some(size),
732 }
733 cx.notify();
734 }
735
736 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
737 matches!(event, AssistantPanelEvent::ZoomIn)
738 }
739
740 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
741 matches!(event, AssistantPanelEvent::ZoomOut)
742 }
743
744 fn is_zoomed(&self, _: &WindowContext) -> bool {
745 self.zoomed
746 }
747
748 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
749 self.zoomed = zoomed;
750 cx.notify();
751 }
752
753 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
754 if active {
755 if self.api_key.borrow().is_none() && !self.has_read_credentials {
756 self.has_read_credentials = true;
757 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
758 Some(api_key)
759 } else if let Some((_, api_key)) = cx
760 .platform()
761 .read_credentials(OPENAI_API_URL)
762 .log_err()
763 .flatten()
764 {
765 String::from_utf8(api_key).log_err()
766 } else {
767 None
768 };
769 if let Some(api_key) = api_key {
770 *self.api_key.borrow_mut() = Some(api_key);
771 } else if self.api_key_editor.is_none() {
772 self.api_key_editor = Some(build_api_key_editor(cx));
773 cx.notify();
774 }
775 }
776
777 if self.editors.is_empty() {
778 self.new_conversation(cx);
779 }
780 }
781 }
782
783 fn icon_path(&self) -> &'static str {
784 "icons/robot_14.svg"
785 }
786
787 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
788 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
789 }
790
791 fn should_change_position_on_event(event: &Self::Event) -> bool {
792 matches!(event, AssistantPanelEvent::DockPositionChanged)
793 }
794
795 fn should_activate_on_event(_: &Self::Event) -> bool {
796 false
797 }
798
799 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
800 matches!(event, AssistantPanelEvent::Close)
801 }
802
803 fn has_focus(&self, _: &WindowContext) -> bool {
804 self.has_focus
805 }
806
807 fn is_focus_event(event: &Self::Event) -> bool {
808 matches!(event, AssistantPanelEvent::Focus)
809 }
810}
811
/// Events emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// The conversation's buffer was edited (emitted by `handle_buffer_event`).
    MessagesEdited,
    /// NOTE(review): presumably emitted when `summary` changes; no emitter is
    /// visible in this section of the file.
    SummaryChanged,
    /// NOTE(review): presumably emitted as streamed completion text arrives; no
    /// emitter is visible in this section of the file.
    StreamedCompletion,
}
817
/// A conversation's summary text and whether it is complete.
#[derive(Default)]
struct Summary {
    /// The summary text; saved as the conversation's `summary`.
    text: String,
    /// Whether the summary is final. Summaries restored from disk are marked done.
    /// NOTE(review): presumably `false` while a summary is still being generated.
    done: bool,
}
823
/// Model for one assistant conversation: a single buffer holding every message's
/// text, plus per-message anchors and metadata.
struct Conversation {
    /// Text of all messages; its language is set to Markdown once loaded.
    buffer: ModelHandle<Buffer>,
    /// Each message's id and start anchor in `buffer`.
    /// NOTE(review): presumably kept in buffer order — confirm.
    message_anchors: Vec<MessageAnchor>,
    /// Role, send time and status for each message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Next id to hand out; always greater than every id in `message_anchors`.
    next_message_id: MessageId,
    /// `None` for a new conversation; restored (and marked done) when loaded.
    summary: Option<Summary>,
    /// In-flight summary task, if any.
    pending_summary: Task<Option<()>>,
    /// NOTE(review): presumably counts completions started; not updated in this section.
    completion_count: usize,
    /// NOTE(review): presumably in-flight completion requests; not updated in this section.
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name used for requests and token counting.
    model: String,
    /// Tokens used by all messages; `None` until the first count finishes.
    token_count: Option<usize>,
    /// Context window size of `model`, from `tiktoken_rs::model::get_context_size`.
    max_token_count: usize,
    /// In-flight token-count task; replaced by each `count_remaining_tokens` call.
    pending_token_count: Task<Option<()>>,
    /// API key shared with the panel and the other conversations.
    api_key: Rc<RefCell<Option<String>>>,
    /// In-flight save task, if any.
    pending_save: Task<Result<()>>,
    /// File this conversation was loaded from; `None` for a new conversation.
    path: Option<PathBuf>,
    /// Subscription to `buffer` events.
    _subscriptions: Vec<Subscription>,
}
842
// `Conversation` emits `ConversationEvent`s via `cx.emit`.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
846
847impl Conversation {
848 fn new(
849 api_key: Rc<RefCell<Option<String>>>,
850 language_registry: Arc<LanguageRegistry>,
851 cx: &mut ModelContext<Self>,
852 ) -> Self {
853 let model = "gpt-3.5-turbo-0613";
854 let markdown = language_registry.language_for_name("Markdown");
855 let buffer = cx.add_model(|cx| {
856 let mut buffer = Buffer::new(0, "", cx);
857 buffer.set_language_registry(language_registry);
858 cx.spawn_weak(|buffer, mut cx| async move {
859 let markdown = markdown.await?;
860 let buffer = buffer
861 .upgrade(&cx)
862 .ok_or_else(|| anyhow!("buffer was dropped"))?;
863 buffer.update(&mut cx, |buffer, cx| {
864 buffer.set_language(Some(markdown), cx)
865 });
866 anyhow::Ok(())
867 })
868 .detach_and_log_err(cx);
869 buffer
870 });
871
872 let mut this = Self {
873 message_anchors: Default::default(),
874 messages_metadata: Default::default(),
875 next_message_id: Default::default(),
876 summary: None,
877 pending_summary: Task::ready(None),
878 completion_count: Default::default(),
879 pending_completions: Default::default(),
880 token_count: None,
881 max_token_count: tiktoken_rs::model::get_context_size(model),
882 pending_token_count: Task::ready(None),
883 model: model.into(),
884 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
885 pending_save: Task::ready(Ok(())),
886 path: None,
887 api_key,
888 buffer,
889 };
890 let message = MessageAnchor {
891 id: MessageId(post_inc(&mut this.next_message_id.0)),
892 start: language::Anchor::MIN,
893 };
894 this.message_anchors.push(message.clone());
895 this.messages_metadata.insert(
896 message.id,
897 MessageMetadata {
898 role: Role::User,
899 sent_at: Local::now(),
900 status: MessageStatus::Done,
901 },
902 );
903
904 this.count_remaining_tokens(cx);
905 this
906 }
907
908 fn serialize(&self, cx: &AppContext) -> SavedConversation {
909 SavedConversation {
910 zed: "conversation".into(),
911 version: SavedConversation::VERSION.into(),
912 text: self.buffer.read(cx).text(),
913 message_metadata: self.messages_metadata.clone(),
914 messages: self
915 .messages(cx)
916 .map(|message| SavedMessage {
917 id: message.id,
918 start: message.offset_range.start,
919 })
920 .collect(),
921 summary: self
922 .summary
923 .as_ref()
924 .map(|summary| summary.text.clone())
925 .unwrap_or_default(),
926 model: self.model.clone(),
927 }
928 }
929
/// Rebuilds a conversation from a `SavedConversation` that was loaded from `path`.
///
/// - The buffer is recreated from the saved text, and each saved message is
///   anchored (`anchor_before`) at its saved start offset.
/// - `next_message_id` becomes one past the largest saved id (0 if there are none).
/// - Metadata and model are restored, and the saved summary is marked done.
/// - The buffer's language becomes Markdown once resolved, and a token count is
///   scheduled immediately.
fn deserialize(
    saved_conversation: SavedConversation,
    path: PathBuf,
    api_key: Rc<RefCell<Option<String>>>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut ModelContext<Self>,
) -> Self {
    let model = saved_conversation.model;
    let markdown = language_registry.language_for_name("Markdown");
    let mut message_anchors = Vec::new();
    let mut next_message_id = MessageId(0);
    let buffer = cx.add_model(|cx| {
        let mut buffer = Buffer::new(0, saved_conversation.text, cx);
        // Anchors are created here, right after the buffer is built from the saved
        // text, so each saved offset refers to that text.
        for message in saved_conversation.messages {
            message_anchors.push(MessageAnchor {
                id: message.id,
                start: buffer.anchor_before(message.start),
            });
            next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
        }
        buffer.set_language_registry(language_registry);
        // Apply Markdown once the language has loaded, if the buffer still exists.
        cx.spawn_weak(|buffer, mut cx| async move {
            let markdown = markdown.await?;
            let buffer = buffer
                .upgrade(&cx)
                .ok_or_else(|| anyhow!("buffer was dropped"))?;
            buffer.update(&mut cx, |buffer, cx| {
                buffer.set_language(Some(markdown), cx)
            });
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
        buffer
    });

    let mut this = Self {
        message_anchors,
        messages_metadata: saved_conversation.message_metadata,
        next_message_id,
        summary: Some(Summary {
            text: saved_conversation.summary,
            done: true,
        }),
        pending_summary: Task::ready(None),
        completion_count: Default::default(),
        pending_completions: Default::default(),
        token_count: None,
        max_token_count: tiktoken_rs::model::get_context_size(&model),
        pending_token_count: Task::ready(None),
        model,
        _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
        pending_save: Task::ready(Ok(())),
        path: Some(path),
        api_key,
        buffer,
    };
    this.count_remaining_tokens(cx);
    this
}
989
990 fn handle_buffer_event(
991 &mut self,
992 _: ModelHandle<Buffer>,
993 event: &language::Event,
994 cx: &mut ModelContext<Self>,
995 ) {
996 match event {
997 language::Event::Edited => {
998 self.count_remaining_tokens(cx);
999 cx.emit(ConversationEvent::MessagesEdited);
1000 }
1001 _ => {}
1002 }
1003 }
1004
1005 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1006 let messages = self
1007 .messages(cx)
1008 .into_iter()
1009 .filter_map(|message| {
1010 Some(tiktoken_rs::ChatCompletionRequestMessage {
1011 role: match message.role {
1012 Role::User => "user".into(),
1013 Role::Assistant => "assistant".into(),
1014 Role::System => "system".into(),
1015 },
1016 content: self
1017 .buffer
1018 .read(cx)
1019 .text_for_range(message.offset_range)
1020 .collect(),
1021 name: None,
1022 })
1023 })
1024 .collect::<Vec<_>>();
1025 let model = self.model.clone();
1026 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1027 async move {
1028 cx.background().timer(Duration::from_millis(200)).await;
1029 let token_count = cx
1030 .background()
1031 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1032 .await?;
1033
1034 this.upgrade(&cx)
1035 .ok_or_else(|| anyhow!("conversation was dropped"))?
1036 .update(&mut cx, |this, cx| {
1037 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1038 this.token_count = Some(token_count);
1039 cx.notify()
1040 });
1041 anyhow::Ok(())
1042 }
1043 .log_err()
1044 });
1045 }
1046
1047 fn remaining_tokens(&self) -> Option<isize> {
1048 Some(self.max_token_count as isize - self.token_count? as isize)
1049 }
1050
/// Switches the conversation to `model`. A token recount is scheduled, because
/// the count and context size are computed per model, and observers are notified.
fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
    self.model = model;
    self.count_remaining_tokens(cx);
    cx.notify();
}
1056
1057 fn assist(
1058 &mut self,
1059 selected_messages: HashSet<MessageId>,
1060 cx: &mut ModelContext<Self>,
1061 ) -> Vec<MessageAnchor> {
1062 let mut user_messages = Vec::new();
1063 let mut tasks = Vec::new();
1064
1065 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1066 message
1067 .start
1068 .is_valid(self.buffer.read(cx))
1069 .then_some(message.id)
1070 });
1071
1072 for selected_message_id in selected_messages {
1073 let selected_message_role =
1074 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1075 metadata.role
1076 } else {
1077 continue;
1078 };
1079
1080 if selected_message_role == Role::Assistant {
1081 if let Some(user_message) = self.insert_message_after(
1082 selected_message_id,
1083 Role::User,
1084 MessageStatus::Done,
1085 cx,
1086 ) {
1087 user_messages.push(user_message);
1088 } else {
1089 continue;
1090 }
1091 } else {
1092 let request = OpenAIRequest {
1093 model: self.model.clone(),
1094 messages: self
1095 .messages(cx)
1096 .filter(|message| matches!(message.status, MessageStatus::Done))
1097 .flat_map(|message| {
1098 let mut system_message = None;
1099 if message.id == selected_message_id {
1100 system_message = Some(RequestMessage {
1101 role: Role::System,
1102 content: concat!(
1103 "Treat the following messages as additional knowledge you have learned about, ",
1104 "but act as if they were not part of this conversation. That is, treat them ",
1105 "as if the user didn't see them and couldn't possibly inquire about them."
1106 ).into()
1107 });
1108 }
1109
1110 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1111 })
1112 .chain(Some(RequestMessage {
1113 role: Role::System,
1114 content: format!(
1115 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1116 selected_message_id.0
1117 ),
1118 }))
1119 .collect(),
1120 stream: true,
1121 };
1122
1123 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1124 let stream = stream_completion(api_key, cx.background().clone(), request);
1125 let assistant_message = self
1126 .insert_message_after(
1127 selected_message_id,
1128 Role::Assistant,
1129 MessageStatus::Pending,
1130 cx,
1131 )
1132 .unwrap();
1133
1134 // Queue up the user's next reply
1135 if Some(selected_message_id) == last_message_id {
1136 let user_message = self
1137 .insert_message_after(
1138 assistant_message.id,
1139 Role::User,
1140 MessageStatus::Done,
1141 cx,
1142 )
1143 .unwrap();
1144 user_messages.push(user_message);
1145 }
1146
1147 tasks.push(cx.spawn_weak({
1148 |this, mut cx| async move {
1149 let assistant_message_id = assistant_message.id;
1150 let stream_completion = async {
1151 let mut messages = stream.await?;
1152
1153 while let Some(message) = messages.next().await {
1154 let mut message = message?;
1155 if let Some(choice) = message.choices.pop() {
1156 this.upgrade(&cx)
1157 .ok_or_else(|| anyhow!("conversation was dropped"))?
1158 .update(&mut cx, |this, cx| {
1159 let text: Arc<str> = choice.delta.content?.into();
1160 let message_ix = this.message_anchors.iter().position(
1161 |message| message.id == assistant_message_id,
1162 )?;
1163 this.buffer.update(cx, |buffer, cx| {
1164 let offset = this.message_anchors[message_ix + 1..]
1165 .iter()
1166 .find(|message| message.start.is_valid(buffer))
1167 .map_or(buffer.len(), |message| {
1168 message
1169 .start
1170 .to_offset(buffer)
1171 .saturating_sub(1)
1172 });
1173 buffer.edit([(offset..offset, text)], None, cx);
1174 });
1175 cx.emit(ConversationEvent::StreamedCompletion);
1176
1177 Some(())
1178 });
1179 }
1180 smol::future::yield_now().await;
1181 }
1182
1183 this.upgrade(&cx)
1184 .ok_or_else(|| anyhow!("conversation was dropped"))?
1185 .update(&mut cx, |this, cx| {
1186 this.pending_completions.retain(|completion| {
1187 completion.id != this.completion_count
1188 });
1189 this.summarize(cx);
1190 });
1191
1192 anyhow::Ok(())
1193 };
1194
1195 let result = stream_completion.await;
1196 if let Some(this) = this.upgrade(&cx) {
1197 this.update(&mut cx, |this, cx| {
1198 if let Some(metadata) =
1199 this.messages_metadata.get_mut(&assistant_message.id)
1200 {
1201 match result {
1202 Ok(_) => {
1203 metadata.status = MessageStatus::Done;
1204 }
1205 Err(error) => {
1206 metadata.status = MessageStatus::Error(
1207 error.to_string().trim().into(),
1208 );
1209 }
1210 }
1211 cx.notify();
1212 }
1213 });
1214 }
1215 }
1216 }));
1217 }
1218 }
1219
1220 if !tasks.is_empty() {
1221 self.pending_completions.push(PendingCompletion {
1222 id: post_inc(&mut self.completion_count),
1223 _tasks: tasks,
1224 });
1225 }
1226
1227 user_messages
1228 }
1229
1230 fn cancel_last_assist(&mut self) -> bool {
1231 self.pending_completions.pop().is_some()
1232 }
1233
1234 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1235 for id in ids {
1236 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1237 metadata.role.cycle();
1238 cx.emit(ConversationEvent::MessagesEdited);
1239 cx.notify();
1240 }
1241 }
1242 }
1243
1244 fn insert_message_after(
1245 &mut self,
1246 message_id: MessageId,
1247 role: Role,
1248 status: MessageStatus,
1249 cx: &mut ModelContext<Self>,
1250 ) -> Option<MessageAnchor> {
1251 if let Some(prev_message_ix) = self
1252 .message_anchors
1253 .iter()
1254 .position(|message| message.id == message_id)
1255 {
1256 // Find the next valid message after the one we were given.
1257 let mut next_message_ix = prev_message_ix + 1;
1258 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1259 if next_message.start.is_valid(self.buffer.read(cx)) {
1260 break;
1261 }
1262 next_message_ix += 1;
1263 }
1264
1265 let start = self.buffer.update(cx, |buffer, cx| {
1266 let offset = self
1267 .message_anchors
1268 .get(next_message_ix)
1269 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1270 buffer.edit([(offset..offset, "\n")], None, cx);
1271 buffer.anchor_before(offset + 1)
1272 });
1273 let message = MessageAnchor {
1274 id: MessageId(post_inc(&mut self.next_message_id.0)),
1275 start,
1276 };
1277 self.message_anchors
1278 .insert(next_message_ix, message.clone());
1279 self.messages_metadata.insert(
1280 message.id,
1281 MessageMetadata {
1282 role,
1283 sent_at: Local::now(),
1284 status,
1285 },
1286 );
1287 cx.emit(ConversationEvent::MessagesEdited);
1288 Some(message)
1289 } else {
1290 None
1291 }
1292 }
1293
1294 fn split_message(
1295 &mut self,
1296 range: Range<usize>,
1297 cx: &mut ModelContext<Self>,
1298 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1299 let start_message = self.message_for_offset(range.start, cx);
1300 let end_message = self.message_for_offset(range.end, cx);
1301 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1302 // Prevent splitting when range spans multiple messages.
1303 if start_message.id != end_message.id {
1304 return (None, None);
1305 }
1306
1307 let message = start_message;
1308 let role = message.role;
1309 let mut edited_buffer = false;
1310
1311 let mut suffix_start = None;
1312 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1313 {
1314 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1315 suffix_start = Some(range.end + 1);
1316 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1317 suffix_start = Some(range.end);
1318 }
1319 }
1320
1321 let suffix = if let Some(suffix_start) = suffix_start {
1322 MessageAnchor {
1323 id: MessageId(post_inc(&mut self.next_message_id.0)),
1324 start: self.buffer.read(cx).anchor_before(suffix_start),
1325 }
1326 } else {
1327 self.buffer.update(cx, |buffer, cx| {
1328 buffer.edit([(range.end..range.end, "\n")], None, cx);
1329 });
1330 edited_buffer = true;
1331 MessageAnchor {
1332 id: MessageId(post_inc(&mut self.next_message_id.0)),
1333 start: self.buffer.read(cx).anchor_before(range.end + 1),
1334 }
1335 };
1336
1337 self.message_anchors
1338 .insert(message.index_range.end + 1, suffix.clone());
1339 self.messages_metadata.insert(
1340 suffix.id,
1341 MessageMetadata {
1342 role,
1343 sent_at: Local::now(),
1344 status: MessageStatus::Done,
1345 },
1346 );
1347
1348 let new_messages =
1349 if range.start == range.end || range.start == message.offset_range.start {
1350 (None, Some(suffix))
1351 } else {
1352 let mut prefix_end = None;
1353 if range.start > message.offset_range.start
1354 && range.end < message.offset_range.end - 1
1355 {
1356 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1357 prefix_end = Some(range.start + 1);
1358 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1359 == Some('\n')
1360 {
1361 prefix_end = Some(range.start);
1362 }
1363 }
1364
1365 let selection = if let Some(prefix_end) = prefix_end {
1366 cx.emit(ConversationEvent::MessagesEdited);
1367 MessageAnchor {
1368 id: MessageId(post_inc(&mut self.next_message_id.0)),
1369 start: self.buffer.read(cx).anchor_before(prefix_end),
1370 }
1371 } else {
1372 self.buffer.update(cx, |buffer, cx| {
1373 buffer.edit([(range.start..range.start, "\n")], None, cx)
1374 });
1375 edited_buffer = true;
1376 MessageAnchor {
1377 id: MessageId(post_inc(&mut self.next_message_id.0)),
1378 start: self.buffer.read(cx).anchor_before(range.end + 1),
1379 }
1380 };
1381
1382 self.message_anchors
1383 .insert(message.index_range.end + 1, selection.clone());
1384 self.messages_metadata.insert(
1385 selection.id,
1386 MessageMetadata {
1387 role,
1388 sent_at: Local::now(),
1389 status: MessageStatus::Done,
1390 },
1391 );
1392 (Some(selection), Some(suffix))
1393 };
1394
1395 if !edited_buffer {
1396 cx.emit(ConversationEvent::MessagesEdited);
1397 }
1398 new_messages
1399 } else {
1400 (None, None)
1401 }
1402 }
1403
1404 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1405 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1406 let api_key = self.api_key.borrow().clone();
1407 if let Some(api_key) = api_key {
1408 let messages = self
1409 .messages(cx)
1410 .take(2)
1411 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1412 .chain(Some(RequestMessage {
1413 role: Role::User,
1414 content:
1415 "Summarize the conversation into a short title without punctuation"
1416 .into(),
1417 }));
1418 let request = OpenAIRequest {
1419 model: self.model.clone(),
1420 messages: messages.collect(),
1421 stream: true,
1422 };
1423
1424 let stream = stream_completion(api_key, cx.background().clone(), request);
1425 self.pending_summary = cx.spawn(|this, mut cx| {
1426 async move {
1427 let mut messages = stream.await?;
1428
1429 while let Some(message) = messages.next().await {
1430 let mut message = message?;
1431 if let Some(choice) = message.choices.pop() {
1432 let text = choice.delta.content.unwrap_or_default();
1433 this.update(&mut cx, |this, cx| {
1434 this.summary
1435 .get_or_insert(Default::default())
1436 .text
1437 .push_str(&text);
1438 cx.emit(ConversationEvent::SummaryChanged);
1439 });
1440 }
1441 }
1442
1443 this.update(&mut cx, |this, cx| {
1444 if let Some(summary) = this.summary.as_mut() {
1445 summary.done = true;
1446 cx.emit(ConversationEvent::SummaryChanged);
1447 }
1448 });
1449
1450 anyhow::Ok(())
1451 }
1452 .log_err()
1453 });
1454 }
1455 }
1456 }
1457
1458 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1459 self.messages_for_offsets([offset], cx).pop()
1460 }
1461
/// Returns the messages containing each of `offsets`.
///
/// `offsets` is expected in ascending order: the scan over messages only moves
/// forward. Consecutive offsets that fall in the same message yield that message
/// once. An offset no message range contains (e.g. the buffer's end offset)
/// resolves to the last message. Once the last message is reached, all remaining
/// offsets are attributed to it.
fn messages_for_offsets(
    &self,
    offsets: impl IntoIterator<Item = usize>,
    cx: &AppContext,
) -> Vec<Message> {
    let mut result = Vec::new();

    let mut messages = self.messages(cx).peekable();
    let mut offsets = offsets.into_iter().peekable();
    let mut current_message = messages.next();
    while let Some(offset) = offsets.next() {
        // Locate the message that contains the offset.
        // Stop at the last message even if it doesn't contain the offset.
        while current_message.as_ref().map_or(false, |message| {
            !message.offset_range.contains(&offset) && messages.peek().is_some()
        }) {
            current_message = messages.next();
        }
        let Some(message) = current_message.as_ref() else { break };

        // Skip offsets that are in the same message.
        // On the last message, every remaining offset is consumed here.
        while offsets.peek().map_or(false, |offset| {
            message.offset_range.contains(offset) || messages.peek().is_none()
        }) {
            offsets.next();
        }

        result.push(message.clone());
    }
    result
}
1492
1493 fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1494 let buffer = self.buffer.read(cx);
1495 let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
1496 iter::from_fn(move || {
1497 while let Some((start_ix, message_anchor)) = message_anchors.next() {
1498 let metadata = self.messages_metadata.get(&message_anchor.id)?;
1499 let message_start = message_anchor.start.to_offset(buffer);
1500 let mut message_end = None;
1501 let mut end_ix = start_ix;
1502 while let Some((_, next_message)) = message_anchors.peek() {
1503 if next_message.start.is_valid(buffer) {
1504 message_end = Some(next_message.start);
1505 break;
1506 } else {
1507 end_ix += 1;
1508 message_anchors.next();
1509 }
1510 }
1511 let message_end = message_end
1512 .unwrap_or(language::Anchor::MAX)
1513 .to_offset(buffer);
1514 return Some(Message {
1515 index_range: start_ix..end_ix,
1516 offset_range: message_start..message_end,
1517 id: message_anchor.id,
1518 anchor: message_anchor.start,
1519 role: metadata.role,
1520 sent_at: metadata.sent_at,
1521 status: metadata.status.clone(),
1522 });
1523 }
1524 None
1525 })
1526 }
1527
1528 fn save(
1529 &mut self,
1530 debounce: Option<Duration>,
1531 fs: Arc<dyn Fs>,
1532 cx: &mut ModelContext<Conversation>,
1533 ) {
1534 self.pending_save = cx.spawn(|this, mut cx| async move {
1535 if let Some(debounce) = debounce {
1536 cx.background().timer(debounce).await;
1537 }
1538
1539 let (old_path, summary) = this.read_with(&cx, |this, _| {
1540 let path = this.path.clone();
1541 let summary = if let Some(summary) = this.summary.as_ref() {
1542 if summary.done {
1543 Some(summary.text.clone())
1544 } else {
1545 None
1546 }
1547 } else {
1548 None
1549 };
1550 (path, summary)
1551 });
1552
1553 if let Some(summary) = summary {
1554 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1555 let path = if let Some(old_path) = old_path {
1556 old_path
1557 } else {
1558 let mut discriminant = 1;
1559 let mut new_path;
1560 loop {
1561 new_path = CONVERSATIONS_DIR.join(&format!(
1562 "{} - {}.zed.json",
1563 summary.trim(),
1564 discriminant
1565 ));
1566 if fs.is_file(&new_path).await {
1567 discriminant += 1;
1568 } else {
1569 break;
1570 }
1571 }
1572 new_path
1573 };
1574
1575 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1576 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1577 .await?;
1578 this.update(&mut cx, |this, _| this.path = Some(path));
1579 }
1580
1581 Ok(())
1582 });
1583 }
1584}
1585
/// The streaming completion tasks started by one call to `Conversation::assist`.
///
/// `Conversation::cancel_last_assist` pops the most recent entry, dropping its tasks.
struct PendingCompletion {
    /// Identifier taken from `Conversation::completion_count` via `post_inc`.
    id: usize,
    /// Never read; held only to keep the spawned tasks alive.
    _tasks: Vec<Task<()>>,
}
1590
/// Events emitted by a `ConversationEditor`.
enum ConversationEditorEvent {
    /// The tab's content (its title) may have changed; emitted when the
    /// conversation's summary changes.
    TabContentChanged,
}
1594
/// Position of the newest cursor relative to the viewport, captured so the
/// viewport can be kept steady while a completion streams into the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: the horizontal scroll position; `y`: how many display rows the cursor
    /// sits below the top of the viewport.
    offset_before_cursor: Vector2F,
    /// Head of the newest selection.
    cursor: Anchor,
}
1600
/// View that edits one `Conversation` through an embedded `Editor`, with a
/// header block rendered above every message.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem the conversation is saved to.
    fs: Arc<dyn Fs>,
    /// Editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor position to keep fixed in the viewport while completions stream in;
    /// `None` when the viewport should not be pinned.
    scroll_position: Option<ScrollPosition>,
    _subscriptions: Vec<Subscription>,
}
1609
impl ConversationEditor {
    /// Creates an editor for a brand-new conversation.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::for_conversation(conversation, fs, cx)
    }

    /// Creates an editor for an existing `conversation`.
    ///
    /// The inner editor soft-wraps at its width and hides the gutter. This view
    /// re-renders whenever the conversation notifies, and it subscribes to
    /// conversation and editor events. Message headers are inserted immediately.
    fn for_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }

    /// `Assist` action: requests replies for the messages under the cursors, then
    /// moves the cursors into any user messages the conversation inserted.
    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);

        let user_messages = self.conversation.update(cx, |conversation, cx| {
            let selected_messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.assist(selected_messages, cx)
        });
        let new_selections = user_messages
            .iter()
            .map(|message| {
                let cursor = message
                    .start
                    .to_offset(self.conversation.read(cx).buffer.read(cx));
                cursor..cursor
            })
            .collect::<Vec<_>>();
        if !new_selections.is_empty() {
            self.editor.update(cx, |editor, cx| {
                editor.change_selections(
                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                    cx,
                    |selections| selections.select_ranges(new_selections),
                );
            });
            // Avoid scrolling to the new cursor position so the assistant's output is stable.
            cx.defer(|this, _| this.scroll_position = None);
        }
    }

    /// `editor::Cancel` action: cancels the most recent pending assist. When none
    /// is pending, the action is propagated instead.
    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
        if !self
            .conversation
            .update(cx, |conversation, _| conversation.cancel_last_assist())
        {
            cx.propagate_action();
        }
    }

    /// `CycleMessageRole` action: cycles the role of every message under a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }

    /// Offsets of the heads of all selections in the editor.
    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
        let selections = self.editor.read(cx).selections.all::<usize>(cx);
        selections
            .into_iter()
            .map(|selection| selection.head())
            .collect()
    }

    /// Reacts to conversation events:
    /// - `MessagesEdited`: rebuild message headers and save after a 500ms debounce;
    /// - `SummaryChanged`: refresh the tab title and save right away;
    /// - `StreamedCompletion`: if a scroll position is pinned, scroll so the pinned
    ///   cursor stays at the same row of the viewport.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }

    /// Tracks whether the viewport should stay pinned to the cursor.
    ///
    /// Autoscrolls and selection changes re-capture the cursor's position in the
    /// viewport. A scroll that isn't an autoscroll and that moves the cursor within
    /// the viewport unpins it.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }

    /// Position of the newest cursor relative to the viewport, or `None` when the
    /// cursor's row is outside the visible rows.
    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
        self.editor.update(cx, |editor, cx| {
            let snapshot = editor.snapshot(cx);
            let cursor = editor.selections.newest_anchor().head();
            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
            let scroll_position = editor
                .scroll_manager
                .anchor()
                .scroll_position(&snapshot.display_snapshot);

            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
                Some(ScrollPosition {
                    cursor,
                    offset_before_cursor: vec2f(
                        scroll_position.x(),
                        cursor_row - scroll_position.y(),
                    ),
                })
            } else {
                None
            }
        })
    }

    /// Replaces all message-header blocks in the editor.
    ///
    /// Each message gets a sticky, two-line block above it. The block shows the
    /// sender (clicking it cycles the message's role), the time the message was
    /// sent, and an error icon with a tooltip when the message failed.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // The editor was created from a single buffer, so there is exactly one excerpt.
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }

    /// `QuoteSelection` action (registered on the workspace): copies the newest
    /// selection of the active editor into the assistant panel.
    ///
    /// Markdown is quoted line by line with `> `. Other text is wrapped in a fenced
    /// code block tagged with the language name, or left untagged when the
    /// selection's start and end are in different languages. The assistant panel
    /// is focused if it isn't already. The text is inserted into the panel's active
    /// conversation, or a new one if there is none.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                let conversation = panel
                    .active_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }

    /// `editor::Copy` action. When a single selection spans more than one message,
    /// copies each spanned part under a `## <role>` heading. Otherwise the action
    /// is propagated for regular copying.
    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
        let editor = self.editor.read(cx);
        let conversation = self.conversation.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let mut copied_text = String::new();
            let mut spanned_messages = 0;
            for message in conversation.messages(cx) {
                if message.offset_range.start >= selection.range().end {
                    break;
                } else if message.offset_range.end >= selection.range().start {
                    // Part of this message covered by the selection.
                    let range = cmp::max(message.offset_range.start, selection.range().start)
                        ..cmp::min(message.offset_range.end, selection.range().end);
                    if !range.is_empty() {
                        spanned_messages += 1;
                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                        for chunk in conversation.buffer.read(cx).text_for_range(range) {
                            copied_text.push_str(&chunk);
                        }
                        copied_text.push('\n');
                    }
                }
            }

            if spanned_messages > 1 {
                cx.platform()
                    .write_to_clipboard(ClipboardItem::new(copied_text));
                return;
            }
        }

        cx.propagate_action();
    }

    /// `Split` action: splits the message at every selection.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                // Re-snapshot each time: the previous split may have edited the buffer.
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }

    /// `Save` action: saves the conversation immediately (no debounce).
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }

    /// Toggles the model: `gpt-4-0613` becomes `gpt-3.5-turbo-0613`; any other model
    /// becomes `gpt-4-0613`.
    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let new_model = match conversation.model.as_str() {
                "gpt-4-0613" => "gpt-3.5-turbo-0613",
                _ => "gpt-4-0613",
            };
            conversation.set_model(new_model.into(), cx);
        });
    }

    /// The conversation's summary text (possibly still streaming), or
    /// `"New Conversation"` when there is none yet.
    fn title(&self, cx: &AppContext) -> String {
        self.conversation
            .read(cx)
            .summary
            .as_ref()
            .map(|summary| summary.text.clone())
            .unwrap_or_else(|| "New Conversation".into())
    }

    /// Clickable label showing the current model; clicking cycles the model.
    fn render_current_model(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> impl Element<Self> {
        enum Model {}

        MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
            let style = style.model.style_for(state);
            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
                .contained()
                .with_style(style.container)
        })
        .with_cursor_style(CursorStyle::PointingHand)
        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
    }

    /// Label with the number of remaining tokens, styled as `no_remaining_tokens`
    /// when it is zero or negative. `None` until a token count is available.
    fn render_remaining_tokens(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> Option<impl Element<Self>> {
        let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
        let remaining_tokens_style = if remaining_tokens <= 0 {
            &style.no_remaining_tokens
        } else {
            &style.remaining_tokens
        };
        Some(
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container),
        )
    }
}
2077
// The editor's event type; observers receive `ConversationEditorEvent`s.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2081
2082impl View for ConversationEditor {
2083 fn ui_name() -> &'static str {
2084 "ConversationEditor"
2085 }
2086
2087 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2088 let theme = &theme::current(cx).assistant;
2089 Stack::new()
2090 .with_child(
2091 ChildView::new(&self.editor, cx)
2092 .contained()
2093 .with_style(theme.container),
2094 )
2095 .with_child(
2096 Flex::row()
2097 .with_child(self.render_current_model(theme, cx))
2098 .with_children(self.render_remaining_tokens(theme, cx))
2099 .aligned()
2100 .top()
2101 .right(),
2102 )
2103 .into_any()
2104 }
2105
2106 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2107 if cx.is_self_focused() {
2108 cx.focus(&self.editor);
2109 }
2110 }
2111}
2112
/// A message's id plus the buffer anchor where its text starts.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    /// Start of the message in the conversation buffer. When the text this anchor
    /// is attached to is deleted, the anchor stops being valid and
    /// `Conversation::messages` folds the message into the preceding one.
    start: language::Anchor,
}
2118
/// One message resolved against the current buffer contents, as produced by
/// `Conversation::messages`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Buffer offsets of the message text, up to the start of the next valid
    /// message (or the end of the buffer).
    offset_range: Range<usize>,
    /// Indices into `Conversation::message_anchors` covered by this message. `end`
    /// is the index of the last anchor folded into it, i.e. inclusive.
    index_range: Range<usize>,
    id: MessageId,
    /// The message's start anchor in the buffer.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2129
2130impl Message {
2131 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2132 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2133 content.extend(buffer.text_for_range(self.offset_range.clone()));
2134 RequestMessage {
2135 role: self.role,
2136 content: content.trim_end().into(),
2137 }
2138 }
2139}
2140
2141async fn stream_completion(
2142 api_key: String,
2143 executor: Arc<Background>,
2144 mut request: OpenAIRequest,
2145) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2146 request.stream = true;
2147
2148 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2149
2150 let json_data = serde_json::to_string(&request)?;
2151 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2152 .header("Content-Type", "application/json")
2153 .header("Authorization", format!("Bearer {}", api_key))
2154 .body(json_data)?
2155 .send_async()
2156 .await?;
2157
2158 let status = response.status();
2159 if status == StatusCode::OK {
2160 executor
2161 .spawn(async move {
2162 let mut lines = BufReader::new(response.body_mut()).lines();
2163
2164 fn parse_line(
2165 line: Result<String, io::Error>,
2166 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2167 if let Some(data) = line?.strip_prefix("data: ") {
2168 let event = serde_json::from_str(&data)?;
2169 Ok(Some(event))
2170 } else {
2171 Ok(None)
2172 }
2173 }
2174
2175 while let Some(line) = lines.next().await {
2176 if let Some(event) = parse_line(line).transpose() {
2177 let done = event.as_ref().map_or(false, |event| {
2178 event
2179 .choices
2180 .last()
2181 .map_or(false, |choice| choice.finish_reason.is_some())
2182 });
2183 if tx.unbounded_send(event).is_err() {
2184 break;
2185 }
2186
2187 if done {
2188 break;
2189 }
2190 }
2191 }
2192
2193 anyhow::Ok(())
2194 })
2195 .detach();
2196
2197 Ok(rx)
2198 } else {
2199 let mut body = String::new();
2200 response.body_mut().read_to_string(&mut body).await?;
2201
2202 #[derive(Deserialize)]
2203 struct OpenAIResponse {
2204 error: OpenAIError,
2205 }
2206
2207 #[derive(Deserialize)]
2208 struct OpenAIError {
2209 message: String,
2210 }
2211
2212 match serde_json::from_str::<OpenAIResponse>(&body) {
2213 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2214 "Failed to connect to OpenAI API: {}",
2215 response.error.message,
2216 )),
2217
2218 _ => Err(anyhow!(
2219 "Failed to connect to OpenAI API: {} {}",
2220 response.status(),
2221 body,
2222 )),
2223 }
2224 }
2225}
2226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    /// Checks that message roles and offset ranges stay consistent while messages are inserted
    /// with `insert_message_after` and the buffer is edited directly, including a deletion that
    /// spans message boundaries followed by undo and redo.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A new conversation starts with a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 adds one character to the buffer: message_1 becomes 0..1
        // and the new, empty message_2 starts right after it.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text inserted at offset 1 (the start of message_2) is attributed to message_2.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        // Appending after the last message grows message_2 by one character (2..3 -> 2..4).
        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places message_4 between message_2 and message_3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    /// Checks `split_message` for empty and non-empty ranges: when it reuses an existing
    /// newline, when it inserts a new one, and which messages it returns.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // Splitting at an empty range yields only the second element of the returned pair.
        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        // Offset 9 follows the newline at offset 8, so that newline is reused and the text
        // is unchanged.
        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        // Splitting at the same offset again now inserts a newline, creating a one-character
        // message_4 (9..10).
        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range ("dd" at 14..16) returns two new messages: one for the
        // selected text and one for the text after it. A newline is inserted after the
        // selection, while the newline before it (offset 13) is reused.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    /// Checks that `messages_for_offsets` maps buffer offsets to the messages containing them.
    /// Offsets that fall in the same message yield it once, and the end-of-buffer offset maps
    /// to the last message.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build "aaa\nbbb\nccc" as three messages, inserting each message's text after
        // creating it.
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in message_1, which appears once; offset 11 is the end of
        // the buffer and maps to message_3.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // After appending an empty message_4 at 12..12, offset 12 maps to it.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        /// Returns the ids of the messages `messages_for_offsets` reports for `offsets`.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    /// Round-trips a conversation through `serialize`/`deserialize` and checks that the buffer
    /// text and the message ids, roles and ranges are preserved.
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // NOTE(review): presumably finalized so the `undo` below reverts only the buffer change
        // made by inserting `_message_3`, not this edit — confirm.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // After the undo, `_message_3` no longer appears in the message list.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    /// Returns `(id, role, offset_range)` for each message, in the order
    /// `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}