1use crate::{
2 assistant_settings::{AssistantDockPosition, AssistantSettings},
3 MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
4 RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
5};
6use anyhow::{anyhow, Result};
7use chrono::{DateTime, Local};
8use collections::{HashMap, HashSet};
9use editor::{
10 display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint},
11 scroll::autoscroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, ToOffset,
13};
14use fs::Fs;
15use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
16use gpui::{
17 actions,
18 elements::*,
19 executor::Background,
20 geometry::vector::{vec2f, Vector2F},
21 platform::{CursorStyle, MouseButton},
22 Action, AppContext, AsyncAppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
23 Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext,
24};
25use isahc::{http::StatusCode, Request, RequestExt};
26use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _};
27use search::BufferSearchBar;
28use serde::Deserialize;
29use settings::SettingsStore;
30use std::{
31 cell::RefCell,
32 cmp, env,
33 fmt::Write,
34 io, iter,
35 ops::Range,
36 path::{Path, PathBuf},
37 rc::Rc,
38 sync::Arc,
39 time::Duration,
40};
41use theme::AssistantStyle;
42use util::{paths::CONVERSATIONS_DIR, post_inc, ResultExt, TryFutureExt};
43use workspace::{
44 dock::{DockPosition, Panel},
45 searchable::Direction,
46 Save, ToggleZoom, Toolbar, Workspace,
47};
48
/// Base URL of the OpenAI API.
///
/// This file also uses it as the key under which the user's API key is written to,
/// read from, and deleted from the platform credential store.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
50
// Actions in the `assistant` namespace. They are bound to handlers in `init` and are
// referenced by the panel's buttons as tooltip key-binding hints.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
    ]
);
63
64pub fn init(cx: &mut AppContext) {
65 settings::register::<AssistantSettings>(cx);
66 cx.add_action(
67 |this: &mut AssistantPanel,
68 _: &workspace::NewFile,
69 cx: &mut ViewContext<AssistantPanel>| {
70 this.new_conversation(cx);
71 },
72 );
73 cx.add_action(ConversationEditor::assist);
74 cx.capture_action(ConversationEditor::cancel_last_assist);
75 cx.capture_action(ConversationEditor::save);
76 cx.add_action(ConversationEditor::quote_selection);
77 cx.capture_action(ConversationEditor::copy);
78 cx.add_action(ConversationEditor::split);
79 cx.capture_action(ConversationEditor::cycle_message_role);
80 cx.add_action(AssistantPanel::save_api_key);
81 cx.add_action(AssistantPanel::reset_api_key);
82 cx.add_action(AssistantPanel::toggle_zoom);
83 cx.add_action(AssistantPanel::deploy);
84 cx.add_action(AssistantPanel::select_next_match);
85 cx.add_action(AssistantPanel::select_prev_match);
86 cx.add_action(AssistantPanel::handle_editor_cancel);
87 cx.add_action(
88 |workspace: &mut Workspace, _: &ToggleFocus, cx: &mut ViewContext<Workspace>| {
89 workspace.toggle_panel_focus::<AssistantPanel>(cx);
90 },
91 );
92}
93
/// Events emitted by [`AssistantPanel`] and interpreted by the workspace dock through
/// the `Panel` trait implementation.
#[derive(Debug)]
pub enum AssistantPanelEvent {
    /// The panel asked to be zoomed in (emitted by `toggle_zoom` when not zoomed).
    ZoomIn,
    /// The panel asked to be zoomed out (emitted by `toggle_zoom` when zoomed).
    ZoomOut,
    /// Treated as a focus event by `Panel::is_focus_event`.
    Focus,
    /// Treated as a close request by `Panel::should_close_on_event`.
    Close,
    /// The dock position configured in settings changed.
    DockPositionChanged,
}
102
/// The assistant dock panel: hosts the open conversation editors, the history of saved
/// conversations, the OpenAI API key prompt, and a toolbar with buffer search.
pub struct AssistantPanel {
    /// Workspace that owns the panel (used e.g. by the "quote selection" button).
    workspace: WeakViewHandle<Workspace>,
    /// User-resized width when docked left/right; `None` means the settings default.
    width: Option<f32>,
    /// User-resized height when docked at the bottom; `None` means the settings default.
    height: Option<f32>,
    /// Index into `editors` of the conversation being shown; `None` shows the history list.
    active_editor_index: Option<usize>,
    /// The value `active_editor_index` had before its most recent change.
    prev_active_editor_index: Option<usize>,
    /// All conversation editors opened in this panel, in creation order.
    editors: Vec<ViewHandle<ConversationEditor>>,
    /// Metadata of conversations saved on disk, refreshed by the directory watcher.
    saved_conversations: Vec<SavedConversationMetadata>,
    /// Scroll state of the saved-conversation list.
    saved_conversations_list_state: UniformListState,
    zoomed: bool,
    has_focus: bool,
    /// Toolbar shown above the active conversation; contains a `BufferSearchBar`.
    toolbar: ViewHandle<Toolbar>,
    /// API key shared with every conversation; `None` until read or entered.
    api_key: Rc<RefCell<Option<String>>>,
    /// Editor for typing the API key; `Some` while the key prompt is displayed.
    api_key_editor: Option<ViewHandle<Editor>>,
    /// Set once the environment/credential store has been consulted for a key.
    has_read_credentials: bool,
    languages: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    /// Subscriptions kept alive for the lifetime of the panel.
    subscriptions: Vec<Subscription>,
    /// Held so the saved-conversations directory watcher task is not dropped.
    _watch_saved_conversations: Task<Result<()>>,
}
123
124impl AssistantPanel {
125 pub fn load(
126 workspace: WeakViewHandle<Workspace>,
127 cx: AsyncAppContext,
128 ) -> Task<Result<ViewHandle<Self>>> {
129 cx.spawn(|mut cx| async move {
130 let fs = workspace.read_with(&cx, |workspace, _| workspace.app_state().fs.clone())?;
131 let saved_conversations = SavedConversationMetadata::list(fs.clone())
132 .await
133 .log_err()
134 .unwrap_or_default();
135
136 // TODO: deserialize state.
137 let workspace_handle = workspace.clone();
138 workspace.update(&mut cx, |workspace, cx| {
139 cx.add_view::<Self, _>(|cx| {
140 const CONVERSATION_WATCH_DURATION: Duration = Duration::from_millis(100);
141 let _watch_saved_conversations = cx.spawn(move |this, mut cx| async move {
142 let mut events = fs
143 .watch(&CONVERSATIONS_DIR, CONVERSATION_WATCH_DURATION)
144 .await;
145 while events.next().await.is_some() {
146 let saved_conversations = SavedConversationMetadata::list(fs.clone())
147 .await
148 .log_err()
149 .unwrap_or_default();
150 this.update(&mut cx, |this, cx| {
151 this.saved_conversations = saved_conversations;
152 cx.notify();
153 })
154 .ok();
155 }
156
157 anyhow::Ok(())
158 });
159
160 let toolbar = cx.add_view(|cx| {
161 let mut toolbar = Toolbar::new(None);
162 toolbar.set_can_navigate(false, cx);
163 toolbar.add_item(cx.add_view(|cx| BufferSearchBar::new(cx)), cx);
164 toolbar
165 });
166 let mut this = Self {
167 workspace: workspace_handle,
168 active_editor_index: Default::default(),
169 prev_active_editor_index: Default::default(),
170 editors: Default::default(),
171 saved_conversations,
172 saved_conversations_list_state: Default::default(),
173 zoomed: false,
174 has_focus: false,
175 toolbar,
176 api_key: Rc::new(RefCell::new(None)),
177 api_key_editor: None,
178 has_read_credentials: false,
179 languages: workspace.app_state().languages.clone(),
180 fs: workspace.app_state().fs.clone(),
181 width: None,
182 height: None,
183 subscriptions: Default::default(),
184 _watch_saved_conversations,
185 };
186
187 let mut old_dock_position = this.position(cx);
188 this.subscriptions =
189 vec![cx.observe_global::<SettingsStore, _>(move |this, cx| {
190 let new_dock_position = this.position(cx);
191 if new_dock_position != old_dock_position {
192 old_dock_position = new_dock_position;
193 cx.emit(AssistantPanelEvent::DockPositionChanged);
194 }
195 })];
196
197 this
198 })
199 })
200 })
201 }
202
203 fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
204 let editor = cx.add_view(|cx| {
205 ConversationEditor::new(
206 self.api_key.clone(),
207 self.languages.clone(),
208 self.fs.clone(),
209 cx,
210 )
211 });
212 self.add_conversation(editor.clone(), cx);
213 editor
214 }
215
216 fn add_conversation(
217 &mut self,
218 editor: ViewHandle<ConversationEditor>,
219 cx: &mut ViewContext<Self>,
220 ) {
221 self.subscriptions
222 .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
223
224 let conversation = editor.read(cx).conversation.clone();
225 self.subscriptions
226 .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
227
228 let index = self.editors.len();
229 self.editors.push(editor);
230 self.set_active_editor_index(Some(index), cx);
231 }
232
233 fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
234 self.prev_active_editor_index = self.active_editor_index;
235 self.active_editor_index = index;
236 if let Some(editor) = self.active_editor() {
237 let editor = editor.read(cx).editor.clone();
238 self.toolbar.update(cx, |toolbar, cx| {
239 toolbar.set_active_item(Some(&editor), cx);
240 });
241 if self.has_focus(cx) {
242 cx.focus(&editor);
243 }
244 } else {
245 self.toolbar.update(cx, |toolbar, cx| {
246 toolbar.set_active_item(None, cx);
247 });
248 }
249
250 cx.notify();
251 }
252
253 fn handle_conversation_editor_event(
254 &mut self,
255 _: ViewHandle<ConversationEditor>,
256 event: &ConversationEditorEvent,
257 cx: &mut ViewContext<Self>,
258 ) {
259 match event {
260 ConversationEditorEvent::TabContentChanged => cx.notify(),
261 }
262 }
263
264 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
265 if let Some(api_key) = self
266 .api_key_editor
267 .as_ref()
268 .map(|editor| editor.read(cx).text(cx))
269 {
270 if !api_key.is_empty() {
271 cx.platform()
272 .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
273 .log_err();
274 *self.api_key.borrow_mut() = Some(api_key);
275 self.api_key_editor.take();
276 cx.focus_self();
277 cx.notify();
278 }
279 } else {
280 cx.propagate_action();
281 }
282 }
283
284 fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
285 cx.platform().delete_credentials(OPENAI_API_URL).log_err();
286 self.api_key.take();
287 self.api_key_editor = Some(build_api_key_editor(cx));
288 cx.focus_self();
289 cx.notify();
290 }
291
292 fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
293 if self.zoomed {
294 cx.emit(AssistantPanelEvent::ZoomOut)
295 } else {
296 cx.emit(AssistantPanelEvent::ZoomIn)
297 }
298 }
299
300 fn deploy(&mut self, action: &search::buffer_search::Deploy, cx: &mut ViewContext<Self>) {
301 let mut propagate_action = true;
302 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
303 search_bar.update(cx, |search_bar, cx| {
304 if search_bar.show(cx) {
305 search_bar.search_suggested(cx);
306 if action.focus {
307 search_bar.select_query(cx);
308 cx.focus_self();
309 }
310 propagate_action = false
311 }
312 });
313 }
314 if propagate_action {
315 cx.propagate_action();
316 }
317 }
318
319 fn handle_editor_cancel(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
320 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
321 if !search_bar.read(cx).is_dismissed() {
322 search_bar.update(cx, |search_bar, cx| {
323 search_bar.dismiss(&Default::default(), cx)
324 });
325 return;
326 }
327 }
328 cx.propagate_action();
329 }
330
331 fn select_next_match(&mut self, _: &search::SelectNextMatch, cx: &mut ViewContext<Self>) {
332 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
333 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Next, 1, cx));
334 }
335 }
336
337 fn select_prev_match(&mut self, _: &search::SelectPrevMatch, cx: &mut ViewContext<Self>) {
338 if let Some(search_bar) = self.toolbar.read(cx).item_of_type::<BufferSearchBar>() {
339 search_bar.update(cx, |bar, cx| bar.select_match(Direction::Prev, 1, cx));
340 }
341 }
342
343 fn active_editor(&self) -> Option<&ViewHandle<ConversationEditor>> {
344 self.editors.get(self.active_editor_index?)
345 }
346
347 fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
348 enum History {}
349 let theme = theme::current(cx);
350 let tooltip_style = theme::current(cx).tooltip.clone();
351 MouseEventHandler::<History, _>::new(0, cx, |state, _| {
352 let style = theme.assistant.hamburger_button.style_for(state);
353 Svg::for_style(style.icon.clone())
354 .contained()
355 .with_style(style.container)
356 })
357 .with_cursor_style(CursorStyle::PointingHand)
358 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
359 if this.active_editor().is_some() {
360 this.set_active_editor_index(None, cx);
361 } else {
362 this.set_active_editor_index(this.prev_active_editor_index, cx);
363 }
364 })
365 .with_tooltip::<History>(1, "History", None, tooltip_style, cx)
366 }
367
368 fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement<Self>> {
369 if self.active_editor().is_some() {
370 vec![
371 Self::render_split_button(cx).into_any(),
372 Self::render_quote_button(cx).into_any(),
373 Self::render_assist_button(cx).into_any(),
374 ]
375 } else {
376 Default::default()
377 }
378 }
379
380 fn render_split_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
381 let theme = theme::current(cx);
382 let tooltip_style = theme::current(cx).tooltip.clone();
383 MouseEventHandler::<Split, _>::new(0, cx, |state, _| {
384 let style = theme.assistant.split_button.style_for(state);
385 Svg::for_style(style.icon.clone())
386 .contained()
387 .with_style(style.container)
388 })
389 .with_cursor_style(CursorStyle::PointingHand)
390 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
391 if let Some(active_editor) = this.active_editor() {
392 active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
393 }
394 })
395 .with_tooltip::<Split>(
396 1,
397 "Split Message",
398 Some(Box::new(Split)),
399 tooltip_style,
400 cx,
401 )
402 }
403
404 fn render_assist_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
405 let theme = theme::current(cx);
406 let tooltip_style = theme::current(cx).tooltip.clone();
407 MouseEventHandler::<Assist, _>::new(0, cx, |state, _| {
408 let style = theme.assistant.assist_button.style_for(state);
409 Svg::for_style(style.icon.clone())
410 .contained()
411 .with_style(style.container)
412 })
413 .with_cursor_style(CursorStyle::PointingHand)
414 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
415 if let Some(active_editor) = this.active_editor() {
416 active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
417 }
418 })
419 .with_tooltip::<Assist>(1, "Assist", Some(Box::new(Assist)), tooltip_style, cx)
420 }
421
422 fn render_quote_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
423 let theme = theme::current(cx);
424 let tooltip_style = theme::current(cx).tooltip.clone();
425 MouseEventHandler::<QuoteSelection, _>::new(0, cx, |state, _| {
426 let style = theme.assistant.quote_button.style_for(state);
427 Svg::for_style(style.icon.clone())
428 .contained()
429 .with_style(style.container)
430 })
431 .with_cursor_style(CursorStyle::PointingHand)
432 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
433 if let Some(workspace) = this.workspace.upgrade(cx) {
434 cx.window_context().defer(move |cx| {
435 workspace.update(cx, |workspace, cx| {
436 ConversationEditor::quote_selection(workspace, &Default::default(), cx)
437 });
438 });
439 }
440 })
441 .with_tooltip::<QuoteSelection>(
442 1,
443 "Quote Selection",
444 Some(Box::new(QuoteSelection)),
445 tooltip_style,
446 cx,
447 )
448 }
449
450 fn render_plus_button(cx: &mut ViewContext<Self>) -> impl Element<Self> {
451 let theme = theme::current(cx);
452 let tooltip_style = theme::current(cx).tooltip.clone();
453 MouseEventHandler::<NewConversation, _>::new(0, cx, |state, _| {
454 let style = theme.assistant.plus_button.style_for(state);
455 Svg::for_style(style.icon.clone())
456 .contained()
457 .with_style(style.container)
458 })
459 .with_cursor_style(CursorStyle::PointingHand)
460 .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
461 this.new_conversation(cx);
462 })
463 .with_tooltip::<NewConversation>(
464 1,
465 "New Conversation",
466 Some(Box::new(NewConversation)),
467 tooltip_style,
468 cx,
469 )
470 }
471
472 fn render_zoom_button(&self, cx: &mut ViewContext<Self>) -> impl Element<Self> {
473 enum ToggleZoomButton {}
474
475 let theme = theme::current(cx);
476 let tooltip_style = theme::current(cx).tooltip.clone();
477 let style = if self.zoomed {
478 &theme.assistant.zoom_out_button
479 } else {
480 &theme.assistant.zoom_in_button
481 };
482
483 MouseEventHandler::<ToggleZoomButton, _>::new(0, cx, |state, _| {
484 let style = style.style_for(state);
485 Svg::for_style(style.icon.clone())
486 .contained()
487 .with_style(style.container)
488 })
489 .with_cursor_style(CursorStyle::PointingHand)
490 .on_click(MouseButton::Left, |_, this, cx| {
491 this.toggle_zoom(&ToggleZoom, cx);
492 })
493 .with_tooltip::<ToggleZoom>(
494 0,
495 if self.zoomed { "Zoom Out" } else { "Zoom In" },
496 Some(Box::new(ToggleZoom)),
497 tooltip_style,
498 cx,
499 )
500 }
501
502 fn render_saved_conversation(
503 &mut self,
504 index: usize,
505 cx: &mut ViewContext<Self>,
506 ) -> impl Element<Self> {
507 let conversation = &self.saved_conversations[index];
508 let path = conversation.path.clone();
509 MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
510 let style = &theme::current(cx).assistant.saved_conversation;
511 Flex::row()
512 .with_child(
513 Label::new(
514 conversation.mtime.format("%F %I:%M%p").to_string(),
515 style.saved_at.text.clone(),
516 )
517 .aligned()
518 .contained()
519 .with_style(style.saved_at.container),
520 )
521 .with_child(
522 Label::new(conversation.title.clone(), style.title.text.clone())
523 .aligned()
524 .contained()
525 .with_style(style.title.container),
526 )
527 .contained()
528 .with_style(*style.container.style_for(state))
529 })
530 .with_cursor_style(CursorStyle::PointingHand)
531 .on_click(MouseButton::Left, move |_, this, cx| {
532 this.open_conversation(path.clone(), cx)
533 .detach_and_log_err(cx)
534 })
535 }
536
537 fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
538 if let Some(ix) = self.editor_index_for_path(&path, cx) {
539 self.set_active_editor_index(Some(ix), cx);
540 return Task::ready(Ok(()));
541 }
542
543 let fs = self.fs.clone();
544 let api_key = self.api_key.clone();
545 let languages = self.languages.clone();
546 cx.spawn(|this, mut cx| async move {
547 let saved_conversation = fs.load(&path).await?;
548 let saved_conversation = serde_json::from_str(&saved_conversation)?;
549 let conversation = cx.add_model(|cx| {
550 Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
551 });
552 this.update(&mut cx, |this, cx| {
553 // If, by the time we've loaded the conversation, the user has already opened
554 // the same conversation, we don't want to open it again.
555 if let Some(ix) = this.editor_index_for_path(&path, cx) {
556 this.set_active_editor_index(Some(ix), cx);
557 } else {
558 let editor = cx
559 .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
560 this.add_conversation(editor, cx);
561 }
562 })?;
563 Ok(())
564 })
565 }
566
567 fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
568 self.editors
569 .iter()
570 .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
571 }
572}
573
574fn build_api_key_editor(cx: &mut ViewContext<AssistantPanel>) -> ViewHandle<Editor> {
575 cx.add_view(|cx| {
576 let mut editor = Editor::single_line(
577 Some(Arc::new(|theme| theme.assistant.api_key_editor.clone())),
578 cx,
579 );
580 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
581 editor
582 })
583}
584
// The panel communicates with the workspace dock through `AssistantPanelEvent`s.
impl Entity for AssistantPanel {
    type Event = AssistantPanelEvent;
}
588
589impl View for AssistantPanel {
590 fn ui_name() -> &'static str {
591 "AssistantPanel"
592 }
593
594 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
595 let theme = &theme::current(cx);
596 let style = &theme.assistant;
597 if let Some(api_key_editor) = self.api_key_editor.as_ref() {
598 Flex::column()
599 .with_child(
600 Text::new(
601 "Paste your OpenAI API key and press Enter to use the assistant",
602 style.api_key_prompt.text.clone(),
603 )
604 .aligned(),
605 )
606 .with_child(
607 ChildView::new(api_key_editor, cx)
608 .contained()
609 .with_style(style.api_key_editor.container)
610 .aligned(),
611 )
612 .contained()
613 .with_style(style.api_key_prompt.container)
614 .aligned()
615 .into_any()
616 } else {
617 let title = self.active_editor().map(|editor| {
618 Label::new(editor.read(cx).title(cx), style.title.text.clone())
619 .contained()
620 .with_style(style.title.container)
621 .aligned()
622 .left()
623 .flex(1., false)
624 });
625 let mut header = Flex::row()
626 .with_child(Self::render_hamburger_button(cx).aligned())
627 .with_children(title);
628 if self.has_focus {
629 header.add_children(
630 self.render_editor_tools(cx)
631 .into_iter()
632 .map(|tool| tool.aligned().flex_float()),
633 );
634 header.add_child(Self::render_plus_button(cx).aligned().flex_float());
635 header.add_child(self.render_zoom_button(cx).aligned());
636 }
637
638 Flex::column()
639 .with_child(
640 header
641 .contained()
642 .with_style(theme.workspace.tab_bar.container)
643 .expanded()
644 .constrained()
645 .with_height(theme.workspace.tab_bar.height),
646 )
647 .with_children(if self.toolbar.read(cx).hidden() {
648 None
649 } else {
650 Some(ChildView::new(&self.toolbar, cx).expanded())
651 })
652 .with_child(if let Some(editor) = self.active_editor() {
653 ChildView::new(editor, cx).flex(1., true).into_any()
654 } else {
655 UniformList::new(
656 self.saved_conversations_list_state.clone(),
657 self.saved_conversations.len(),
658 cx,
659 |this, range, items, cx| {
660 for ix in range {
661 items.push(this.render_saved_conversation(ix, cx).into_any());
662 }
663 },
664 )
665 .flex(1., true)
666 .into_any()
667 })
668 .into_any()
669 }
670 }
671
672 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
673 self.has_focus = true;
674 self.toolbar
675 .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
676 cx.notify();
677 if cx.is_self_focused() {
678 if let Some(editor) = self.active_editor() {
679 cx.focus(editor);
680 } else if let Some(api_key_editor) = self.api_key_editor.as_ref() {
681 cx.focus(api_key_editor);
682 }
683 }
684 }
685
686 fn focus_out(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
687 self.has_focus = false;
688 self.toolbar
689 .update(cx, |toolbar, cx| toolbar.focus_changed(false, cx));
690 cx.notify();
691 }
692}
693
694impl Panel for AssistantPanel {
695 fn position(&self, cx: &WindowContext) -> DockPosition {
696 match settings::get::<AssistantSettings>(cx).dock {
697 AssistantDockPosition::Left => DockPosition::Left,
698 AssistantDockPosition::Bottom => DockPosition::Bottom,
699 AssistantDockPosition::Right => DockPosition::Right,
700 }
701 }
702
703 fn position_is_valid(&self, _: DockPosition) -> bool {
704 true
705 }
706
707 fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
708 settings::update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings| {
709 let dock = match position {
710 DockPosition::Left => AssistantDockPosition::Left,
711 DockPosition::Bottom => AssistantDockPosition::Bottom,
712 DockPosition::Right => AssistantDockPosition::Right,
713 };
714 settings.dock = Some(dock);
715 });
716 }
717
718 fn size(&self, cx: &WindowContext) -> f32 {
719 let settings = settings::get::<AssistantSettings>(cx);
720 match self.position(cx) {
721 DockPosition::Left | DockPosition::Right => {
722 self.width.unwrap_or_else(|| settings.default_width)
723 }
724 DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height),
725 }
726 }
727
728 fn set_size(&mut self, size: f32, cx: &mut ViewContext<Self>) {
729 match self.position(cx) {
730 DockPosition::Left | DockPosition::Right => self.width = Some(size),
731 DockPosition::Bottom => self.height = Some(size),
732 }
733 cx.notify();
734 }
735
736 fn should_zoom_in_on_event(event: &AssistantPanelEvent) -> bool {
737 matches!(event, AssistantPanelEvent::ZoomIn)
738 }
739
740 fn should_zoom_out_on_event(event: &AssistantPanelEvent) -> bool {
741 matches!(event, AssistantPanelEvent::ZoomOut)
742 }
743
744 fn is_zoomed(&self, _: &WindowContext) -> bool {
745 self.zoomed
746 }
747
748 fn set_zoomed(&mut self, zoomed: bool, cx: &mut ViewContext<Self>) {
749 self.zoomed = zoomed;
750 cx.notify();
751 }
752
753 fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
754 if active {
755 if self.api_key.borrow().is_none() && !self.has_read_credentials {
756 self.has_read_credentials = true;
757 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
758 Some(api_key)
759 } else if let Some((_, api_key)) = cx
760 .platform()
761 .read_credentials(OPENAI_API_URL)
762 .log_err()
763 .flatten()
764 {
765 String::from_utf8(api_key).log_err()
766 } else {
767 None
768 };
769 if let Some(api_key) = api_key {
770 *self.api_key.borrow_mut() = Some(api_key);
771 } else if self.api_key_editor.is_none() {
772 self.api_key_editor = Some(build_api_key_editor(cx));
773 cx.notify();
774 }
775 }
776
777 if self.editors.is_empty() {
778 self.new_conversation(cx);
779 }
780 }
781 }
782
783 fn icon_path(&self) -> &'static str {
784 "icons/robot_14.svg"
785 }
786
787 fn icon_tooltip(&self) -> (String, Option<Box<dyn Action>>) {
788 ("Assistant Panel".into(), Some(Box::new(ToggleFocus)))
789 }
790
791 fn should_change_position_on_event(event: &Self::Event) -> bool {
792 matches!(event, AssistantPanelEvent::DockPositionChanged)
793 }
794
795 fn should_activate_on_event(_: &Self::Event) -> bool {
796 false
797 }
798
799 fn should_close_on_event(event: &AssistantPanelEvent) -> bool {
800 matches!(event, AssistantPanelEvent::Close)
801 }
802
803 fn has_focus(&self, _: &WindowContext) -> bool {
804 self.has_focus
805 }
806
807 fn is_focus_event(event: &Self::Event) -> bool {
808 matches!(event, AssistantPanelEvent::Focus)
809 }
810}
811
/// Notifications emitted by a [`Conversation`] model.
enum ConversationEvent {
    /// The conversation buffer was edited (emitted from `handle_buffer_event`).
    MessagesEdited,
    /// The conversation's summary text changed.
    SummaryChanged,
    /// Streamed completion output was received.
    StreamedCompletion,
}
817
/// A conversation summary and whether it is final.
///
/// `Default` yields empty text that is not yet done.
#[derive(Default)]
struct Summary {
    /// Summary text accumulated so far.
    text: String,
    /// `true` once the summary is complete (deserialized summaries are always done).
    done: bool,
}
823
/// Model for a single assistant conversation: a Markdown buffer partitioned into messages
/// by anchors, per-message metadata, token accounting, and pending async work.
struct Conversation {
    /// Text of the whole conversation; messages are ranges of this buffer.
    buffer: ModelHandle<Buffer>,
    /// Start anchor of every message, in the order they were created or restored.
    message_anchors: Vec<MessageAnchor>,
    /// Role, send time, and status for each message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Id to assign to the next new message.
    next_message_id: MessageId,
    summary: Option<Summary>,
    /// NOTE(review): presumably the in-flight summary generation task — confirm.
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    /// NOTE(review): presumably in-flight completion requests started by `assist` — confirm.
    pending_completions: Vec<PendingCompletion>,
    /// OpenAI model name used for requests and token counting.
    model: String,
    /// Most recently computed token count of all messages; `None` until first computed.
    token_count: Option<usize>,
    /// Context size of `model`, as reported by `tiktoken_rs`.
    max_token_count: usize,
    /// Debounced token-counting task started by `count_remaining_tokens`.
    pending_token_count: Task<Option<()>>,
    /// API key shared with the panel.
    api_key: Rc<RefCell<Option<String>>>,
    pending_save: Task<Result<()>>,
    /// File the conversation was loaded from or saved to; `None` for unsaved ones.
    path: Option<PathBuf>,
    /// Keeps the buffer event subscription alive.
    _subscriptions: Vec<Subscription>,
}
842
// Conversations notify observers through `ConversationEvent`s.
impl Entity for Conversation {
    type Event = ConversationEvent;
}
846
847impl Conversation {
848 fn new(
849 api_key: Rc<RefCell<Option<String>>>,
850 language_registry: Arc<LanguageRegistry>,
851 cx: &mut ModelContext<Self>,
852 ) -> Self {
853 let model = "gpt-3.5-turbo-0613";
854 let markdown = language_registry.language_for_name("Markdown");
855 let buffer = cx.add_model(|cx| {
856 let mut buffer = Buffer::new(0, "", cx);
857 buffer.set_language_registry(language_registry);
858 cx.spawn_weak(|buffer, mut cx| async move {
859 let markdown = markdown.await?;
860 let buffer = buffer
861 .upgrade(&cx)
862 .ok_or_else(|| anyhow!("buffer was dropped"))?;
863 buffer.update(&mut cx, |buffer, cx| {
864 buffer.set_language(Some(markdown), cx)
865 });
866 anyhow::Ok(())
867 })
868 .detach_and_log_err(cx);
869 buffer
870 });
871
872 let mut this = Self {
873 message_anchors: Default::default(),
874 messages_metadata: Default::default(),
875 next_message_id: Default::default(),
876 summary: None,
877 pending_summary: Task::ready(None),
878 completion_count: Default::default(),
879 pending_completions: Default::default(),
880 token_count: None,
881 max_token_count: tiktoken_rs::model::get_context_size(model),
882 pending_token_count: Task::ready(None),
883 model: model.into(),
884 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
885 pending_save: Task::ready(Ok(())),
886 path: None,
887 api_key,
888 buffer,
889 };
890 let message = MessageAnchor {
891 id: MessageId(post_inc(&mut this.next_message_id.0)),
892 start: language::Anchor::MIN,
893 };
894 this.message_anchors.push(message.clone());
895 this.messages_metadata.insert(
896 message.id,
897 MessageMetadata {
898 role: Role::User,
899 sent_at: Local::now(),
900 status: MessageStatus::Done,
901 },
902 );
903
904 this.count_remaining_tokens(cx);
905 this
906 }
907
908 fn serialize(&self, cx: &AppContext) -> SavedConversation {
909 SavedConversation {
910 zed: "conversation".into(),
911 version: SavedConversation::VERSION.into(),
912 text: self.buffer.read(cx).text(),
913 message_metadata: self.messages_metadata.clone(),
914 messages: self
915 .messages(cx)
916 .map(|message| SavedMessage {
917 id: message.id,
918 start: message.offset_range.start,
919 })
920 .collect(),
921 summary: self
922 .summary
923 .as_ref()
924 .map(|summary| summary.text.clone())
925 .unwrap_or_default(),
926 model: self.model.clone(),
927 }
928 }
929
    /// Rebuilds a conversation from its on-disk form, associating it with `path`.
    ///
    /// Message ids are restored verbatim, with each message anchored before its saved
    /// start offset. `next_message_id` becomes one past the largest restored id, so new
    /// messages do not collide. The saved summary is treated as final (`done: true`). As
    /// in `new`, Markdown is attached to the buffer asynchronously, and the token count
    /// starts as `None` until `count_remaining_tokens` finishes.
    fn deserialize(
        saved_conversation: SavedConversation,
        path: PathBuf,
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let model = saved_conversation.model;
        let markdown = language_registry.language_for_name("Markdown");
        let mut message_anchors = Vec::new();
        let mut next_message_id = MessageId(0);
        let buffer = cx.add_model(|cx| {
            let mut buffer = Buffer::new(0, saved_conversation.text, cx);
            // Anchors must be created against the buffer, so messages are restored here.
            for message in saved_conversation.messages {
                message_anchors.push(MessageAnchor {
                    id: message.id,
                    start: buffer.anchor_before(message.start),
                });
                next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
            }
            buffer.set_language_registry(language_registry);
            cx.spawn_weak(|buffer, mut cx| async move {
                let markdown = markdown.await?;
                let buffer = buffer
                    .upgrade(&cx)
                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
                buffer.update(&mut cx, |buffer, cx| {
                    buffer.set_language(Some(markdown), cx)
                });
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
            buffer
        });

        let mut this = Self {
            message_anchors,
            messages_metadata: saved_conversation.message_metadata,
            next_message_id,
            summary: Some(Summary {
                text: saved_conversation.summary,
                done: true,
            }),
            pending_summary: Task::ready(None),
            completion_count: Default::default(),
            pending_completions: Default::default(),
            token_count: None,
            max_token_count: tiktoken_rs::model::get_context_size(&model),
            pending_token_count: Task::ready(None),
            model,
            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
            pending_save: Task::ready(Ok(())),
            path: Some(path),
            api_key,
            buffer,
        };
        this.count_remaining_tokens(cx);
        this
    }
989
990 fn handle_buffer_event(
991 &mut self,
992 _: ModelHandle<Buffer>,
993 event: &language::Event,
994 cx: &mut ModelContext<Self>,
995 ) {
996 match event {
997 language::Event::Edited => {
998 self.count_remaining_tokens(cx);
999 cx.emit(ConversationEvent::MessagesEdited);
1000 }
1001 _ => {}
1002 }
1003 }
1004
1005 fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1006 let messages = self
1007 .messages(cx)
1008 .into_iter()
1009 .filter_map(|message| {
1010 Some(tiktoken_rs::ChatCompletionRequestMessage {
1011 role: match message.role {
1012 Role::User => "user".into(),
1013 Role::Assistant => "assistant".into(),
1014 Role::System => "system".into(),
1015 },
1016 content: self
1017 .buffer
1018 .read(cx)
1019 .text_for_range(message.offset_range)
1020 .collect(),
1021 name: None,
1022 })
1023 })
1024 .collect::<Vec<_>>();
1025 let model = self.model.clone();
1026 self.pending_token_count = cx.spawn_weak(|this, mut cx| {
1027 async move {
1028 cx.background().timer(Duration::from_millis(200)).await;
1029 let token_count = cx
1030 .background()
1031 .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
1032 .await?;
1033
1034 this.upgrade(&cx)
1035 .ok_or_else(|| anyhow!("conversation was dropped"))?
1036 .update(&mut cx, |this, cx| {
1037 this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
1038 this.token_count = Some(token_count);
1039 cx.notify()
1040 });
1041 anyhow::Ok(())
1042 }
1043 .log_err()
1044 });
1045 }
1046
1047 fn remaining_tokens(&self) -> Option<isize> {
1048 Some(self.max_token_count as isize - self.token_count? as isize)
1049 }
1050
1051 fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
1052 self.model = model;
1053 self.count_remaining_tokens(cx);
1054 cx.notify();
1055 }
1056
1057 fn assist(
1058 &mut self,
1059 selected_messages: HashSet<MessageId>,
1060 cx: &mut ModelContext<Self>,
1061 ) -> Vec<MessageAnchor> {
1062 let mut user_messages = Vec::new();
1063 let mut tasks = Vec::new();
1064
1065 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1066 message
1067 .start
1068 .is_valid(self.buffer.read(cx))
1069 .then_some(message.id)
1070 });
1071
1072 for selected_message_id in selected_messages {
1073 let selected_message_role =
1074 if let Some(metadata) = self.messages_metadata.get(&selected_message_id) {
1075 metadata.role
1076 } else {
1077 continue;
1078 };
1079
1080 if selected_message_role == Role::Assistant {
1081 if let Some(user_message) = self.insert_message_after(
1082 selected_message_id,
1083 Role::User,
1084 MessageStatus::Done,
1085 cx,
1086 ) {
1087 user_messages.push(user_message);
1088 } else {
1089 continue;
1090 }
1091 } else {
1092 let request = OpenAIRequest {
1093 model: self.model.clone(),
1094 messages: self
1095 .messages(cx)
1096 .filter(|message| matches!(message.status, MessageStatus::Done))
1097 .flat_map(|message| {
1098 let mut system_message = None;
1099 if message.id == selected_message_id {
1100 system_message = Some(RequestMessage {
1101 role: Role::System,
1102 content: concat!(
1103 "Treat the following messages as additional knowledge you have learned about, ",
1104 "but act as if they were not part of this conversation. That is, treat them ",
1105 "as if the user didn't see them and couldn't possibly inquire about them."
1106 ).into()
1107 });
1108 }
1109
1110 Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message)
1111 })
1112 .chain(Some(RequestMessage {
1113 role: Role::System,
1114 content: format!(
1115 "Direct your reply to message with id {}. Do not include a [Message X] header.",
1116 selected_message_id.0
1117 ),
1118 }))
1119 .collect(),
1120 stream: true,
1121 };
1122
1123 let Some(api_key) = self.api_key.borrow().clone() else { continue };
1124 let stream = stream_completion(api_key, cx.background().clone(), request);
1125 let assistant_message = self
1126 .insert_message_after(
1127 selected_message_id,
1128 Role::Assistant,
1129 MessageStatus::Pending,
1130 cx,
1131 )
1132 .unwrap();
1133
1134 // Queue up the user's next reply
1135 if Some(selected_message_id) == last_message_id {
1136 let user_message = self
1137 .insert_message_after(
1138 assistant_message.id,
1139 Role::User,
1140 MessageStatus::Done,
1141 cx,
1142 )
1143 .unwrap();
1144 user_messages.push(user_message);
1145 }
1146
1147 tasks.push(cx.spawn_weak({
1148 |this, mut cx| async move {
1149 let assistant_message_id = assistant_message.id;
1150 let stream_completion = async {
1151 let mut messages = stream.await?;
1152
1153 while let Some(message) = messages.next().await {
1154 let mut message = message?;
1155 if let Some(choice) = message.choices.pop() {
1156 this.upgrade(&cx)
1157 .ok_or_else(|| anyhow!("conversation was dropped"))?
1158 .update(&mut cx, |this, cx| {
1159 let text: Arc<str> = choice.delta.content?.into();
1160 let message_ix = this.message_anchors.iter().position(
1161 |message| message.id == assistant_message_id,
1162 )?;
1163 this.buffer.update(cx, |buffer, cx| {
1164 let offset = this.message_anchors[message_ix + 1..]
1165 .iter()
1166 .find(|message| message.start.is_valid(buffer))
1167 .map_or(buffer.len(), |message| {
1168 message
1169 .start
1170 .to_offset(buffer)
1171 .saturating_sub(1)
1172 });
1173 buffer.edit([(offset..offset, text)], None, cx);
1174 });
1175 cx.emit(ConversationEvent::StreamedCompletion);
1176
1177 Some(())
1178 });
1179 }
1180 smol::future::yield_now().await;
1181 }
1182
1183 this.upgrade(&cx)
1184 .ok_or_else(|| anyhow!("conversation was dropped"))?
1185 .update(&mut cx, |this, cx| {
1186 this.pending_completions.retain(|completion| {
1187 completion.id != this.completion_count
1188 });
1189 this.summarize(cx);
1190 });
1191
1192 anyhow::Ok(())
1193 };
1194
1195 let result = stream_completion.await;
1196 if let Some(this) = this.upgrade(&cx) {
1197 this.update(&mut cx, |this, cx| {
1198 if let Some(metadata) =
1199 this.messages_metadata.get_mut(&assistant_message.id)
1200 {
1201 match result {
1202 Ok(_) => {
1203 metadata.status = MessageStatus::Done;
1204 }
1205 Err(error) => {
1206 metadata.status = MessageStatus::Error(
1207 error.to_string().trim().into(),
1208 );
1209 }
1210 }
1211 cx.notify();
1212 }
1213 });
1214 }
1215 }
1216 }));
1217 }
1218 }
1219
1220 if !tasks.is_empty() {
1221 self.pending_completions.push(PendingCompletion {
1222 id: post_inc(&mut self.completion_count),
1223 _tasks: tasks,
1224 });
1225 }
1226
1227 user_messages
1228 }
1229
1230 fn cancel_last_assist(&mut self) -> bool {
1231 self.pending_completions.pop().is_some()
1232 }
1233
1234 fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1235 for id in ids {
1236 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1237 metadata.role.cycle();
1238 cx.emit(ConversationEvent::MessagesEdited);
1239 cx.notify();
1240 }
1241 }
1242 }
1243
1244 fn insert_message_after(
1245 &mut self,
1246 message_id: MessageId,
1247 role: Role,
1248 status: MessageStatus,
1249 cx: &mut ModelContext<Self>,
1250 ) -> Option<MessageAnchor> {
1251 if let Some(prev_message_ix) = self
1252 .message_anchors
1253 .iter()
1254 .position(|message| message.id == message_id)
1255 {
1256 // Find the next valid message after the one we were given.
1257 let mut next_message_ix = prev_message_ix + 1;
1258 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1259 if next_message.start.is_valid(self.buffer.read(cx)) {
1260 break;
1261 }
1262 next_message_ix += 1;
1263 }
1264
1265 let start = self.buffer.update(cx, |buffer, cx| {
1266 let offset = self
1267 .message_anchors
1268 .get(next_message_ix)
1269 .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
1270 buffer.edit([(offset..offset, "\n")], None, cx);
1271 buffer.anchor_before(offset + 1)
1272 });
1273 let message = MessageAnchor {
1274 id: MessageId(post_inc(&mut self.next_message_id.0)),
1275 start,
1276 };
1277 self.message_anchors
1278 .insert(next_message_ix, message.clone());
1279 self.messages_metadata.insert(
1280 message.id,
1281 MessageMetadata {
1282 role,
1283 sent_at: Local::now(),
1284 status,
1285 },
1286 );
1287 cx.emit(ConversationEvent::MessagesEdited);
1288 Some(message)
1289 } else {
1290 None
1291 }
1292 }
1293
1294 fn split_message(
1295 &mut self,
1296 range: Range<usize>,
1297 cx: &mut ModelContext<Self>,
1298 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1299 let start_message = self.message_for_offset(range.start, cx);
1300 let end_message = self.message_for_offset(range.end, cx);
1301 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1302 // Prevent splitting when range spans multiple messages.
1303 if start_message.id != end_message.id {
1304 return (None, None);
1305 }
1306
1307 let message = start_message;
1308 let role = message.role;
1309 let mut edited_buffer = false;
1310
1311 let mut suffix_start = None;
1312 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1313 {
1314 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1315 suffix_start = Some(range.end + 1);
1316 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1317 suffix_start = Some(range.end);
1318 }
1319 }
1320
1321 let suffix = if let Some(suffix_start) = suffix_start {
1322 MessageAnchor {
1323 id: MessageId(post_inc(&mut self.next_message_id.0)),
1324 start: self.buffer.read(cx).anchor_before(suffix_start),
1325 }
1326 } else {
1327 self.buffer.update(cx, |buffer, cx| {
1328 buffer.edit([(range.end..range.end, "\n")], None, cx);
1329 });
1330 edited_buffer = true;
1331 MessageAnchor {
1332 id: MessageId(post_inc(&mut self.next_message_id.0)),
1333 start: self.buffer.read(cx).anchor_before(range.end + 1),
1334 }
1335 };
1336
1337 self.message_anchors
1338 .insert(message.index_range.end + 1, suffix.clone());
1339 self.messages_metadata.insert(
1340 suffix.id,
1341 MessageMetadata {
1342 role,
1343 sent_at: Local::now(),
1344 status: MessageStatus::Done,
1345 },
1346 );
1347
1348 let new_messages =
1349 if range.start == range.end || range.start == message.offset_range.start {
1350 (None, Some(suffix))
1351 } else {
1352 let mut prefix_end = None;
1353 if range.start > message.offset_range.start
1354 && range.end < message.offset_range.end - 1
1355 {
1356 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1357 prefix_end = Some(range.start + 1);
1358 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1359 == Some('\n')
1360 {
1361 prefix_end = Some(range.start);
1362 }
1363 }
1364
1365 let selection = if let Some(prefix_end) = prefix_end {
1366 cx.emit(ConversationEvent::MessagesEdited);
1367 MessageAnchor {
1368 id: MessageId(post_inc(&mut self.next_message_id.0)),
1369 start: self.buffer.read(cx).anchor_before(prefix_end),
1370 }
1371 } else {
1372 self.buffer.update(cx, |buffer, cx| {
1373 buffer.edit([(range.start..range.start, "\n")], None, cx)
1374 });
1375 edited_buffer = true;
1376 MessageAnchor {
1377 id: MessageId(post_inc(&mut self.next_message_id.0)),
1378 start: self.buffer.read(cx).anchor_before(range.end + 1),
1379 }
1380 };
1381
1382 self.message_anchors
1383 .insert(message.index_range.end + 1, selection.clone());
1384 self.messages_metadata.insert(
1385 selection.id,
1386 MessageMetadata {
1387 role,
1388 sent_at: Local::now(),
1389 status: MessageStatus::Done,
1390 },
1391 );
1392 (Some(selection), Some(suffix))
1393 };
1394
1395 if !edited_buffer {
1396 cx.emit(ConversationEvent::MessagesEdited);
1397 }
1398 new_messages
1399 } else {
1400 (None, None)
1401 }
1402 }
1403
1404 fn summarize(&mut self, cx: &mut ModelContext<Self>) {
1405 if self.message_anchors.len() >= 2 && self.summary.is_none() {
1406 let api_key = self.api_key.borrow().clone();
1407 if let Some(api_key) = api_key {
1408 let messages = self
1409 .messages(cx)
1410 .take(2)
1411 .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
1412 .chain(Some(RequestMessage {
1413 role: Role::User,
1414 content:
1415 "Summarize the conversation into a short title without punctuation"
1416 .into(),
1417 }));
1418 let request = OpenAIRequest {
1419 model: self.model.clone(),
1420 messages: messages.collect(),
1421 stream: true,
1422 };
1423
1424 let stream = stream_completion(api_key, cx.background().clone(), request);
1425 self.pending_summary = cx.spawn(|this, mut cx| {
1426 async move {
1427 let mut messages = stream.await?;
1428
1429 while let Some(message) = messages.next().await {
1430 let mut message = message?;
1431 if let Some(choice) = message.choices.pop() {
1432 let text = choice.delta.content.unwrap_or_default();
1433 this.update(&mut cx, |this, cx| {
1434 this.summary
1435 .get_or_insert(Default::default())
1436 .text
1437 .push_str(&text);
1438 cx.emit(ConversationEvent::SummaryChanged);
1439 });
1440 }
1441 }
1442
1443 this.update(&mut cx, |this, cx| {
1444 if let Some(summary) = this.summary.as_mut() {
1445 summary.done = true;
1446 cx.emit(ConversationEvent::SummaryChanged);
1447 }
1448 });
1449
1450 anyhow::Ok(())
1451 }
1452 .log_err()
1453 });
1454 }
1455 }
1456 }
1457
1458 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1459 self.messages_for_offsets([offset], cx).pop()
1460 }
1461
    /// Maps `offsets` to the messages containing them.
    ///
    /// Messages and offsets are walked together in a single forward pass. `offsets` therefore
    /// presumably has to be sorted ascending (callers pass cursor heads); TODO confirm.
    ///
    /// A run of consecutive offsets falling in the same message contributes one entry. An
    /// offset not contained in any later message resolves to the last message, and once the
    /// last message is reached all remaining offsets are attributed to it.
    fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Stops at the last message even if it doesn't contain the offset.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else { break };

            // Skip offsets that are in the same message.
            // When `message` is the last one, every remaining offset is skipped as well.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
1492
    /// Iterates over the conversation's messages in `message_anchors` order.
    ///
    /// Anchors whose `start` is no longer valid (`is_valid` returns false) are not yielded on
    /// their own. Instead, they are folded into the message that precedes them:
    /// - `offset_range` runs from the message's own anchor to the next *valid* anchor, or to
    ///   the end of the buffer when there is none.
    /// - `index_range` is `start_ix..end_ix`, where `end_ix` is the index of the last folded
    ///   (invalid) anchor, or `start_ix` itself when nothing was folded.
    ///
    /// Iteration ends early if an anchor has no entry in `messages_metadata`.
    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            while let Some((start_ix, message_anchor)) = message_anchors.next() {
                // NOTE(review): `?` ends the whole iterator (not just this item) when the
                // metadata is missing.
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Consume following invalid anchors; this message ends at the next valid one.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                // No valid successor: the message extends to the end of the buffer.
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);
                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    sent_at: metadata.sent_at,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
1527
1528 fn save(
1529 &mut self,
1530 debounce: Option<Duration>,
1531 fs: Arc<dyn Fs>,
1532 cx: &mut ModelContext<Conversation>,
1533 ) {
1534 self.pending_save = cx.spawn(|this, mut cx| async move {
1535 if let Some(debounce) = debounce {
1536 cx.background().timer(debounce).await;
1537 }
1538
1539 let (old_path, summary) = this.read_with(&cx, |this, _| {
1540 let path = this.path.clone();
1541 let summary = if let Some(summary) = this.summary.as_ref() {
1542 if summary.done {
1543 Some(summary.text.clone())
1544 } else {
1545 None
1546 }
1547 } else {
1548 None
1549 };
1550 (path, summary)
1551 });
1552
1553 if let Some(summary) = summary {
1554 let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
1555 let path = if let Some(old_path) = old_path {
1556 old_path
1557 } else {
1558 let mut discriminant = 1;
1559 let mut new_path;
1560 loop {
1561 new_path = CONVERSATIONS_DIR.join(&format!(
1562 "{} - {}.zed.json",
1563 summary.trim(),
1564 discriminant
1565 ));
1566 if fs.is_file(&new_path).await {
1567 discriminant += 1;
1568 } else {
1569 break;
1570 }
1571 }
1572 new_path
1573 };
1574
1575 fs.create_dir(CONVERSATIONS_DIR.as_ref()).await?;
1576 fs.atomic_write(path.clone(), serde_json::to_string(&conversation).unwrap())
1577 .await?;
1578 this.update(&mut cx, |this, _| this.path = Some(path));
1579 }
1580
1581 Ok(())
1582 });
1583 }
1584}
1585
/// A group of in-flight completion tasks started by a single `Conversation::assist` call.
struct PendingCompletion {
    /// Sequential id taken from `Conversation::completion_count` when the group is created.
    id: usize,
    /// Handles of the group's tasks. They are only held, never polled here; presumably dropping
    /// them (e.g. via `Conversation::cancel_last_assist`) cancels the tasks.
    _tasks: Vec<Task<()>>,
}
1590
/// Events emitted by `ConversationEditor`.
enum ConversationEditorEvent {
    /// Emitted when the conversation's summary changed, which the editor's title is based on.
    TabContentChanged,
}
1594
/// The editor's scroll position expressed relative to the newest cursor. It is restored
/// while completion text streams into the buffer.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ScrollPosition {
    /// `x`: the horizontal scroll position. `y`: rows between the top of the viewport and the
    /// cursor's display row.
    offset_before_cursor: Vector2F,
    /// Head of the editor's newest selection.
    cursor: Anchor,
}
1600
/// A view for editing one `Conversation` through an embedded `Editor`.
struct ConversationEditor {
    conversation: ModelHandle<Conversation>,
    /// Filesystem passed to `Conversation::save`.
    fs: Arc<dyn Fs>,
    /// Editor over the conversation's buffer.
    editor: ViewHandle<Editor>,
    /// Ids of the message-header blocks currently inserted into `editor`.
    blocks: HashSet<BlockId>,
    /// Cursor-relative scroll position to keep while completions stream in; `None` when the
    /// cursor is off screen or the user scrolled away.
    scroll_position: Option<ScrollPosition>,
    /// Subscriptions to the conversation and editor, kept for the view's lifetime.
    _subscriptions: Vec<Subscription>,
}
1609
impl ConversationEditor {
    /// Creates an editor for a brand-new conversation.
    fn new(
        api_key: Rc<RefCell<Option<String>>>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
        Self::for_conversation(conversation, fs, cx)
    }

    /// Creates an editor for an existing `conversation`.
    ///
    /// Builds a soft-wrapping `Editor` without gutter or wrap guides over the conversation's
    /// buffer, subscribes to conversation and editor events, and inserts the message headers.
    fn for_conversation(
        conversation: ModelHandle<Conversation>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let editor = cx.add_view(|cx| {
            let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
            editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
            editor.set_show_gutter(false, cx);
            editor.set_show_wrap_guides(false, cx);
            editor
        });

        let _subscriptions = vec![
            cx.observe(&conversation, |_, _, cx| cx.notify()),
            cx.subscribe(&conversation, Self::handle_conversation_event),
            cx.subscribe(&editor, Self::handle_editor_event),
        ];

        let mut this = Self {
            conversation,
            editor,
            blocks: Default::default(),
            scroll_position: None,
            fs,
            _subscriptions,
        };
        this.update_message_headers(cx);
        this
    }

    /// Handles `Assist`: asks the conversation to assist on the messages under the cursors.
    /// If user messages were inserted, the cursors are moved to their starts.
    fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);

        let user_messages = self.conversation.update(cx, |conversation, cx| {
            let selected_messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.assist(selected_messages, cx)
        });
        let new_selections = user_messages
            .iter()
            .map(|message| {
                let cursor = message
                    .start
                    .to_offset(self.conversation.read(cx).buffer.read(cx));
                cursor..cursor
            })
            .collect::<Vec<_>>();
        if !new_selections.is_empty() {
            self.editor.update(cx, |editor, cx| {
                editor.change_selections(
                    Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)),
                    cx,
                    |selections| selections.select_ranges(new_selections),
                );
            });
            // Avoid scrolling to the new cursor position so the assistant's output is stable.
            cx.defer(|this, _| this.scroll_position = None);
        }
    }

    /// Handles `editor::Cancel`: cancels the last pending assist. If nothing was pending, the
    /// action is propagated.
    fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext<Self>) {
        if !self
            .conversation
            .update(cx, |conversation, _| conversation.cancel_last_assist())
        {
            cx.propagate_action();
        }
    }

    /// Handles `CycleMessageRole`: cycles the role of each message under a cursor.
    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
        let cursors = self.cursors(cx);
        self.conversation.update(cx, |conversation, cx| {
            let messages = conversation
                .messages_for_offsets(cursors, cx)
                .into_iter()
                .map(|message| message.id)
                .collect();
            conversation.cycle_message_roles(messages, cx)
        });
    }

    /// Buffer offsets of the heads of all selections in the editor.
    fn cursors(&self, cx: &AppContext) -> Vec<usize> {
        let selections = self.editor.read(cx).selections.all::<usize>(cx);
        selections
            .into_iter()
            .map(|selection| selection.head())
            .collect()
    }

    /// Reacts to conversation events.
    ///
    /// - `MessagesEdited`: refreshes the headers and saves with a 500ms debounce.
    /// - `SummaryChanged`: signals a tab-title change and saves immediately.
    /// - `StreamedCompletion`: restores the saved scroll position so the cursor stays at the
    ///   same row in the viewport.
    fn handle_conversation_event(
        &mut self,
        _: ModelHandle<Conversation>,
        event: &ConversationEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            ConversationEvent::MessagesEdited => {
                self.update_message_headers(cx);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(Some(Duration::from_millis(500)), self.fs.clone(), cx);
                });
            }
            ConversationEvent::SummaryChanged => {
                cx.emit(ConversationEditorEvent::TabContentChanged);
                self.conversation.update(cx, |conversation, cx| {
                    conversation.save(None, self.fs.clone(), cx);
                });
            }
            ConversationEvent::StreamedCompletion => {
                self.editor.update(cx, |editor, cx| {
                    if let Some(scroll_position) = self.scroll_position {
                        let snapshot = editor.snapshot(cx);
                        let cursor_point = scroll_position.cursor.to_display_point(&snapshot);
                        let scroll_top =
                            cursor_point.row() as f32 - scroll_position.offset_before_cursor.y();
                        editor.set_scroll_position(
                            vec2f(scroll_position.offset_before_cursor.x(), scroll_top),
                            cx,
                        );
                    }
                });
            }
        }
    }

    /// Tracks the cursor-relative scroll position.
    ///
    /// Autoscrolls and selection changes recapture it. A scroll that isn't an autoscroll clears
    /// it when the cursor's position relative to the viewport changed.
    fn handle_editor_event(
        &mut self,
        _: ViewHandle<Editor>,
        event: &editor::Event,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            editor::Event::ScrollPositionChanged { autoscroll, .. } => {
                let cursor_scroll_position = self.cursor_scroll_position(cx);
                if *autoscroll {
                    self.scroll_position = cursor_scroll_position;
                } else if self.scroll_position != cursor_scroll_position {
                    self.scroll_position = None;
                }
            }
            editor::Event::SelectionsChanged { .. } => {
                self.scroll_position = self.cursor_scroll_position(cx);
            }
            _ => {}
        }
    }

    /// Position of the newest cursor relative to the viewport.
    ///
    /// Returns `None` when the cursor's display row is outside the visible rows.
    fn cursor_scroll_position(&self, cx: &mut ViewContext<Self>) -> Option<ScrollPosition> {
        self.editor.update(cx, |editor, cx| {
            let snapshot = editor.snapshot(cx);
            let cursor = editor.selections.newest_anchor().head();
            let cursor_row = cursor.to_display_point(&snapshot.display_snapshot).row() as f32;
            let scroll_position = editor
                .scroll_manager
                .anchor()
                .scroll_position(&snapshot.display_snapshot);

            let scroll_bottom = scroll_position.y() + editor.visible_line_count().unwrap_or(0.);
            if (scroll_position.y()..scroll_bottom).contains(&cursor_row) {
                Some(ScrollPosition {
                    cursor,
                    offset_before_cursor: vec2f(
                        scroll_position.x(),
                        cursor_row - scroll_position.y(),
                    ),
                })
            } else {
                None
            }
        })
    }

    /// Replaces all message-header blocks in the editor.
    ///
    /// Each message gets one 2-row sticky header above it. The header shows the sender
    /// (clicking it cycles the message's role), the time the message was sent, and an error
    /// icon with a tooltip when the message's status is an error.
    fn update_message_headers(&mut self, cx: &mut ViewContext<Self>) {
        self.editor.update(cx, |editor, cx| {
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let excerpt_id = *buffer.as_singleton().unwrap().0;
            let old_blocks = std::mem::take(&mut self.blocks);
            let new_blocks = self
                .conversation
                .read(cx)
                .messages(cx)
                .map(|message| BlockProperties {
                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                    height: 2,
                    style: BlockStyle::Sticky,
                    render: Arc::new({
                        let conversation = self.conversation.clone();
                        move |cx| {
                            enum Sender {}
                            enum ErrorTooltip {}

                            let theme = theme::current(cx);
                            let style = &theme.assistant;
                            let message_id = message.id;
                            let sender = MouseEventHandler::<Sender, _>::new(
                                message_id.0,
                                cx,
                                |state, _| match message.role {
                                    Role::User => {
                                        let style = style.user_sender.style_for(state);
                                        Label::new("You", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::Assistant => {
                                        let style = style.assistant_sender.style_for(state);
                                        Label::new("Assistant", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                    Role::System => {
                                        let style = style.system_sender.style_for(state);
                                        Label::new("System", style.text.clone())
                                            .contained()
                                            .with_style(style.container)
                                    }
                                },
                            )
                            .with_cursor_style(CursorStyle::PointingHand)
                            .on_down(MouseButton::Left, {
                                let conversation = conversation.clone();
                                move |_, _, cx| {
                                    conversation.update(cx, |conversation, cx| {
                                        conversation.cycle_message_roles(
                                            HashSet::from_iter(Some(message_id)),
                                            cx,
                                        )
                                    })
                                }
                            });

                            Flex::row()
                                .with_child(sender.aligned())
                                .with_child(
                                    Label::new(
                                        message.sent_at.format("%I:%M%P").to_string(),
                                        style.sent_at.text.clone(),
                                    )
                                    .contained()
                                    .with_style(style.sent_at.container)
                                    .aligned(),
                                )
                                .with_children(
                                    if let MessageStatus::Error(error) = &message.status {
                                        Some(
                                            Svg::new("icons/circle_x_mark_12.svg")
                                                .with_color(style.error_icon.color)
                                                .constrained()
                                                .with_width(style.error_icon.width)
                                                .contained()
                                                .with_style(style.error_icon.container)
                                                .with_tooltip::<ErrorTooltip>(
                                                    message_id.0,
                                                    error.to_string(),
                                                    None,
                                                    theme.tooltip.clone(),
                                                    cx,
                                                )
                                                .aligned(),
                                        )
                                    } else {
                                        None
                                    },
                                )
                                .aligned()
                                .left()
                                .contained()
                                .with_style(style.message_header)
                                .into_any()
                        }
                    }),
                    disposition: BlockDisposition::Above,
                })
                .collect::<Vec<_>>();

            editor.remove_blocks(old_blocks, None, cx);
            let ids = editor.insert_blocks(new_blocks, None, cx);
            self.blocks = HashSet::from_iter(ids);
        });
    }

    /// Handles `QuoteSelection`: inserts the active editor's newest selection into the
    /// assistant panel's active conversation, creating a conversation if there is none.
    ///
    /// The assistant panel is focused first if it isn't already.
    /// - Markdown text is quoted line by line with `> `.
    /// - Other text is wrapped in a code fence tagged with the lowercased language name, which
    ///   is empty when the selection's endpoints have different or no languages.
    /// - An empty selection inserts nothing.
    fn quote_selection(
        workspace: &mut Workspace,
        _: &QuoteSelection,
        cx: &mut ViewContext<Workspace>,
    ) {
        let Some(panel) = workspace.panel::<AssistantPanel>(cx) else {
            return;
        };
        let Some(editor) = workspace.active_item(cx).and_then(|item| item.act_as::<Editor>(cx)) else {
            return;
        };

        let text = editor.read_with(cx, |editor, cx| {
            let range = editor.selections.newest::<usize>(cx).range();
            let buffer = editor.buffer().read(cx).snapshot(cx);
            let start_language = buffer.language_at(range.start);
            let end_language = buffer.language_at(range.end);
            let language_name = if start_language == end_language {
                start_language.map(|language| language.name())
            } else {
                None
            };
            let language_name = language_name.as_deref().unwrap_or("").to_lowercase();

            let selected_text = buffer.text_for_range(range).collect::<String>();
            if selected_text.is_empty() {
                None
            } else {
                Some(if language_name == "markdown" {
                    selected_text
                        .lines()
                        .map(|line| format!("> {}", line))
                        .collect::<Vec<_>>()
                        .join("\n")
                } else {
                    format!("```{language_name}\n{selected_text}\n```")
                })
            }
        });

        // Activate the panel
        if !panel.read(cx).has_focus(cx) {
            workspace.toggle_panel_focus::<AssistantPanel>(cx);
        }

        if let Some(text) = text {
            panel.update(cx, |panel, cx| {
                let conversation = panel
                    .active_editor()
                    .cloned()
                    .unwrap_or_else(|| panel.new_conversation(cx));
                conversation.update(cx, |conversation, cx| {
                    conversation
                        .editor
                        .update(cx, |editor, cx| editor.insert(&text, cx))
                });
            });
        }
    }

    /// Handles `editor::Copy`.
    ///
    /// With a single selection spanning more than one message, copies the selected text with
    /// a `## <role>` heading before each message's part. Otherwise the action is propagated so
    /// the editor performs a normal copy.
    fn copy(&mut self, _: &editor::Copy, cx: &mut ViewContext<Self>) {
        let editor = self.editor.read(cx);
        let conversation = self.conversation.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let mut copied_text = String::new();
            let mut spanned_messages = 0;
            for message in conversation.messages(cx) {
                if message.offset_range.start >= selection.range().end {
                    break;
                } else if message.offset_range.end >= selection.range().start {
                    let range = cmp::max(message.offset_range.start, selection.range().start)
                        ..cmp::min(message.offset_range.end, selection.range().end);
                    if !range.is_empty() {
                        spanned_messages += 1;
                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                        for chunk in conversation.buffer.read(cx).text_for_range(range) {
                            copied_text.push_str(&chunk);
                        }
                        copied_text.push('\n');
                    }
                }
            }

            if spanned_messages > 1 {
                cx.platform()
                    .write_to_clipboard(ClipboardItem::new(copied_text));
                return;
            }
        }

        cx.propagate_action();
    }

    /// Handles `Split`: splits the message under each selection.
    ///
    /// The buffer snapshot is re-taken per selection, because each split may edit the buffer.
    fn split(&mut self, _: &Split, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let selections = self.editor.read(cx).selections.disjoint_anchors();
            for selection in selections.into_iter() {
                let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx);
                let range = selection
                    .map(|endpoint| endpoint.to_offset(&buffer))
                    .range();
                conversation.split_message(range, cx);
            }
        });
    }

    /// Handles `Save`: saves the conversation without debouncing.
    fn save(&mut self, _: &Save, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            conversation.save(None, self.fs.clone(), cx)
        });
    }

    /// Switches from `gpt-4-0613` to `gpt-3.5-turbo-0613`; any other model switches to
    /// `gpt-4-0613`.
    fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
        self.conversation.update(cx, |conversation, cx| {
            let new_model = match conversation.model.as_str() {
                "gpt-4-0613" => "gpt-3.5-turbo-0613",
                _ => "gpt-4-0613",
            };
            conversation.set_model(new_model.into(), cx);
        });
    }

    /// The conversation's summary text so far, or `"New Conversation"` when there is none.
    fn title(&self, cx: &AppContext) -> String {
        self.conversation
            .read(cx)
            .summary
            .as_ref()
            .map(|summary| summary.text.clone())
            .unwrap_or_else(|| "New Conversation".into())
    }

    /// Clickable label showing the conversation's model; clicking it cycles the model.
    fn render_current_model(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> impl Element<Self> {
        enum Model {}

        MouseEventHandler::<Model, _>::new(0, cx, |state, cx| {
            let style = style.model.style_for(state);
            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
                .contained()
                .with_style(style.container)
        })
        .with_cursor_style(CursorStyle::PointingHand)
        .on_click(MouseButton::Left, |_, this, cx| this.cycle_model(cx))
    }

    /// Label showing the remaining token count.
    ///
    /// Styled as "no remaining" at or below 0 and as "low" at or below 500. Returns `None`
    /// until a token count is available.
    fn render_remaining_tokens(
        &self,
        style: &AssistantStyle,
        cx: &mut ViewContext<Self>,
    ) -> Option<impl Element<Self>> {
        let remaining_tokens = self.conversation.read(cx).remaining_tokens()?;
        let remaining_tokens_style = if remaining_tokens <= 0 {
            &style.no_remaining_tokens
        } else if remaining_tokens <= 500 {
            &style.low_remaining_tokens
        } else {
            &style.remaining_tokens
        };
        Some(
            Label::new(
                remaining_tokens.to_string(),
                remaining_tokens_style.text.clone(),
            )
            .contained()
            .with_style(remaining_tokens_style.container),
        )
    }
}
2080
/// `ConversationEditor` emits `ConversationEditorEvent`s to its subscribers.
impl Entity for ConversationEditor {
    type Event = ConversationEditorEvent;
}
2084
2085impl View for ConversationEditor {
2086 fn ui_name() -> &'static str {
2087 "ConversationEditor"
2088 }
2089
2090 fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
2091 let theme = &theme::current(cx).assistant;
2092 Stack::new()
2093 .with_child(
2094 ChildView::new(&self.editor, cx)
2095 .contained()
2096 .with_style(theme.container),
2097 )
2098 .with_child(
2099 Flex::row()
2100 .with_child(self.render_current_model(theme, cx))
2101 .with_children(self.render_remaining_tokens(theme, cx))
2102 .aligned()
2103 .top()
2104 .right(),
2105 )
2106 .into_any()
2107 }
2108
2109 fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
2110 if cx.is_self_focused() {
2111 cx.focus(&self.editor);
2112 }
2113 }
2114}
2115
/// Identity of a message plus the buffer anchor at which its text begins.
#[derive(Clone, Debug)]
struct MessageAnchor {
    id: MessageId,
    /// Start of the message in the conversation buffer. It can become invalid
    /// (`is_valid` returns false), in which case `Conversation::messages` folds this anchor
    /// into the preceding message.
    start: language::Anchor,
}
2121
/// A resolved message, as yielded by `Conversation::messages`.
#[derive(Clone, Debug)]
pub struct Message {
    /// Offsets of the message's text in the conversation buffer. The range ends at the next
    /// valid message's start, or at the end of the buffer.
    offset_range: Range<usize>,
    /// Indices into `Conversation::message_anchors`, from this message's anchor to the last
    /// invalid anchor folded into it (equal bounds when none were folded).
    index_range: Range<usize>,
    id: MessageId,
    /// Anchor at which the message starts.
    anchor: language::Anchor,
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
2132
2133impl Message {
2134 fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
2135 let mut content = format!("[Message {}]\n", self.id.0).to_string();
2136 content.extend(buffer.text_for_range(self.offset_range.clone()));
2137 RequestMessage {
2138 role: self.role,
2139 content: content.trim_end().into(),
2140 }
2141 }
2142}
2143
2144async fn stream_completion(
2145 api_key: String,
2146 executor: Arc<Background>,
2147 mut request: OpenAIRequest,
2148) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
2149 request.stream = true;
2150
2151 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
2152
2153 let json_data = serde_json::to_string(&request)?;
2154 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
2155 .header("Content-Type", "application/json")
2156 .header("Authorization", format!("Bearer {}", api_key))
2157 .body(json_data)?
2158 .send_async()
2159 .await?;
2160
2161 let status = response.status();
2162 if status == StatusCode::OK {
2163 executor
2164 .spawn(async move {
2165 let mut lines = BufReader::new(response.body_mut()).lines();
2166
2167 fn parse_line(
2168 line: Result<String, io::Error>,
2169 ) -> Result<Option<OpenAIResponseStreamEvent>> {
2170 if let Some(data) = line?.strip_prefix("data: ") {
2171 let event = serde_json::from_str(&data)?;
2172 Ok(Some(event))
2173 } else {
2174 Ok(None)
2175 }
2176 }
2177
2178 while let Some(line) = lines.next().await {
2179 if let Some(event) = parse_line(line).transpose() {
2180 let done = event.as_ref().map_or(false, |event| {
2181 event
2182 .choices
2183 .last()
2184 .map_or(false, |choice| choice.finish_reason.is_some())
2185 });
2186 if tx.unbounded_send(event).is_err() {
2187 break;
2188 }
2189
2190 if done {
2191 break;
2192 }
2193 }
2194 }
2195
2196 anyhow::Ok(())
2197 })
2198 .detach();
2199
2200 Ok(rx)
2201 } else {
2202 let mut body = String::new();
2203 response.body_mut().read_to_string(&mut body).await?;
2204
2205 #[derive(Deserialize)]
2206 struct OpenAIResponse {
2207 error: OpenAIError,
2208 }
2209
2210 #[derive(Deserialize)]
2211 struct OpenAIError {
2212 message: String,
2213 }
2214
2215 match serde_json::from_str::<OpenAIResponse>(&body) {
2216 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
2217 "Failed to connect to OpenAI API: {}",
2218 response.error.message,
2219 )),
2220
2221 _ => Err(anyhow!(
2222 "Failed to connect to OpenAI API: {} {}",
2223 response.status(),
2224 body,
2225 )),
2226 }
2227 }
2228}
2229
// Unit tests for `Conversation`'s message bookkeeping: how message ranges react to message
// insertion, buffer edits, undo/redo, splitting, offset lookups, and (de)serialization.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageId;
    use gpui::AppContext;

    // Inserting messages, editing text around them, and deleting across message boundaries
    // (plus undoing/redoing that deletion) must keep every message's offset range consistent.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        // A fresh conversation holds a single, empty user message.
        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting after message_1 grows it to 0..1 (one separator character); the new
        // message starts right after that character.
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text inserted at a message's start offset is attributed to that message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message between message_2 and
        // message_3.
        let message_4 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        // Here message_2 and message_4 disappear and the surviving text joins message_1.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }

    // `split_message` splits the message under a cursor (empty range) or selection. It
    // returns a pair of optional new messages. Each step asserts both the resulting buffer
    // text and the new message layout.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) =
            conversation.update(cx, |conversation, cx| conversation.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        // Splitting at offset 9 (just before the newline after "bbb") recycles that newline:
        // the buffer text is unchanged.
        let (_, message_4) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        // Splitting again at the same offset (now the end of message_2) inserts a newline.
        let (_, message_5) =
            conversation.update(cx, |conversation, cx| conversation.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty selection (the "dd" at 14..16) makes the selection its own
        // message (message_6). The text after it becomes another message (message_7),
        // separated by a newline inserted after the selection.
        let (message_6, message_7) = conversation.update(cx, |conversation, cx| {
            conversation.split_message(14..16, cx)
        });
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }

    // `messages_for_offsets` maps buffer offsets to the messages containing them, yielding
    // each message once even when several offsets fall inside it.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
        let buffer = conversation.read(cx).buffer.clone();

        let message_1 = conversation.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&conversation, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Build "aaa\nbbb\nccc" as three messages; each insert_message_after contributes
        // the newline that separates it from the previous message.
        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 share message_1, which is reported once. Offset 11 (the end of
        // the buffer) resolves to message_3.
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // With an empty trailing message at 12..12, offset 12 resolves to that message.
        let message_4 = conversation
            .update(cx, |conversation, cx| {
                conversation.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&conversation, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&conversation, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Returns the ids of the messages `messages_for_offsets` yields for `offsets`.
        fn message_ids_for_offsets(
            conversation: &ModelHandle<Conversation>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            conversation
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }

    // Serializing a conversation and deserializing it must reproduce the buffer text and the
    // message layout (ids, roles, and offset ranges).
    #[gpui::test]
    fn test_serialization(cx: &mut AppContext) {
        let registry = Arc::new(LanguageRegistry::test());
        let conversation =
            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
        let buffer = conversation.read(cx).buffer.clone();
        let message_0 = conversation.read(cx).message_anchors[0].id;
        let message_1 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Close the edit's transaction so the undo below reverts only message_3's insertion.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = conversation.update(cx, |conversation, cx| {
            conversation
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Undoing message_3's insertion removes its newline, and message_3 no longer appears.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let deserialized_conversation = cx.add_model(|cx| {
            Conversation::deserialize(
                conversation.read(cx).serialize(cx),
                Default::default(),
                Default::default(),
                registry.clone(),
                cx,
            )
        });
        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
        assert_eq!(
            messages(&deserialized_conversation, cx),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }

    // Collects `(id, role, offset_range)` for every message of `conversation`, in the order
    // `Conversation::messages` yields them.
    fn messages(
        conversation: &ModelHandle<Conversation>,
        cx: &AppContext,
    ) -> Vec<(MessageId, Role, Range<usize>)> {
        conversation
            .read(cx)
            .messages(cx)
            .map(|message| (message.id, message.role, message.offset_range))
            .collect()
    }
}